using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using Fantasy.SourceGenerator.Common;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

#pragma warning disable CS0649 // Field is never assigned to, and will always have its default value
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.

namespace Fantasy.SourceGenerator.Generators;

/// <summary>
/// Incremental generator that discovers classes deriving directly from
/// <c>Fantasy.Network.Interface.AMessage</c> and emits three registrar classes into
/// the <c>Fantasy.Generated</c> namespace:
/// <c>NetworkProtocolRegistrar</c> (list of protocol types),
/// <c>NetworkProtocolOpCodeResolverRegistrar</c> (OpCode -> type / custom route type) and
/// <c>NetworkProtocolResponseTypeResolverRegistrar</c> (request OpCode -> response type).
/// </summary>
[Generator]
internal partial class NetworkProtocolGenerator : IIncrementalGenerator
{
    /// <summary>
    /// Wires the syntax provider that collects protocol descriptions and registers the
    /// source output. Output is only produced when the compilation has the Fantasy define
    /// and references <c>Fantasy.Assembly.INetworkProtocolRegistrar</c>.
    /// </summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        var networkProtocols = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => IsNetworkProtocolClass(node),
                transform: static (ctx, _) => GetNetworkProtocolInfo(ctx))
            .Where(static info => info != null)
            .Collect();

        var compilationAndTypes = context.CompilationProvider.Combine(networkProtocols);

        context.RegisterSourceOutput(compilationAndTypes, static (spc, source) =>
        {
            if (!CompilationHelper.HasFantasyDefine(source.Left))
            {
                return;
            }

            // Skip compilations that do not reference the runtime interface we implement.
            if (source.Left.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") == null)
            {
                return;
            }

            GenerateCode(spc, source.Left, source.Right);
        });
    }

    #region GenerateCode

    /// <summary>
    /// Emits all three registrar source files for the collected protocol descriptions.
    /// </summary>
    /// <param name="context">Source production context receiving the generated files.</param>
    /// <param name="compilation">Current compilation (used only for the assembly name in comments).</param>
    /// <param name="networkProtocolTypeInfos">Collected descriptions; null entries are ignored.</param>
    private static void GenerateCode(
        SourceProductionContext context,
        Compilation compilation,
        IEnumerable<NetworkProtocolTypeInfo?> networkProtocolTypeInfos)
    {
        // Every declaration of a partial class resolves to the same symbol, so the same type can be
        // collected more than once. Duplicates would produce repeated `case` labels (CS0152) in the
        // generated switch statements, so keep only the first description per type.
        var networkProtocolTypeInfoList = networkProtocolTypeInfos
            .OfType<NetworkProtocolTypeInfo>()
            .GroupBy(static info => info.FullName, StringComparer.Ordinal)
            .Select(static group => group.First())
            .ToList();

        // Assembly name is only used in generated comments.
        var assemblyName = compilation.AssemblyName ?? "Unknown";

        // Class that registers all network protocol types.
        GenerateNetworkProtocolTypesCode(context, assemblyName, networkProtocolTypeInfoList);
        // OpCode -> type / custom route type helpers.
        GenerateNetworkProtocolOpCodeResolverCode(context, assemblyName, networkProtocolTypeInfoList);
        // Request OpCode -> response type helpers.
        GenerateNetworkProtocolResponseTypesResolverCode(context, assemblyName, networkProtocolTypeInfoList);
    }

    /// <summary>
    /// Emits <c>NetworkProtocolRegistrar.g.cs</c>, whose <c>GetNetworkProtocolTypes</c> returns
    /// the <see cref="Type"/> of every discovered protocol.
    /// </summary>
    private static void GenerateNetworkProtocolTypesCode(
        SourceProductionContext context,
        string assemblyName,
        List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
    {
        var builder = new SourceCodeBuilder();
        builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
        builder.AddUsings(
            "System",
            "System.Collections.Generic",
            "Fantasy.Assembly",
            "Fantasy.DataStructure.Collection"
        );

        // Generated code always lives in Fantasy.Generated.
        builder.BeginNamespace("Fantasy.Generated");

        // Class implementing INetworkProtocolRegistrar.
        builder.AddXmlComment($"Auto-generated NetworkProtocol registration class for {assemblyName}");
        builder.BeginClass("NetworkProtocolRegistrar", "internal sealed", "INetworkProtocolRegistrar");

        builder.BeginMethod("public List<Type> GetNetworkProtocolTypes()");

        if (networkProtocolTypeInfoList.Count > 0)
        {
            // Pre-size the list with the exact number of entries.
            builder.AppendLine($"return new List<Type>({ToInvariant(networkProtocolTypeInfoList.Count)})");
            builder.AppendLine("{");
            builder.Indent();

            foreach (var info in networkProtocolTypeInfoList)
            {
                builder.AppendLine($"typeof({info.FullName}),");
            }

            builder.Unindent();
            builder.AppendLine("};");
            builder.AppendLine();
        }
        else
        {
            builder.AppendLine("return new List<Type>();");
        }

        builder.EndMethod();
        builder.AppendLine();

        builder.EndClass();
        builder.EndNamespace();

        context.AddSource("NetworkProtocolRegistrar.g.cs", builder.ToString());
    }

    /// <summary>
    /// Emits <c>NetworkProtocolOpCodeResolverRegistrar.g.cs</c>, which maps OpCodes to protocol
    /// types and to their statically-known custom route types.
    /// </summary>
    private static void GenerateNetworkProtocolOpCodeResolverCode(
        SourceProductionContext context,
        string assemblyName,
        List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
    {
        var routeTypeInfos = networkProtocolTypeInfoList.Where(static d => d.RouteType.HasValue).ToList();

        var builder = new SourceCodeBuilder();
        builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
        builder.AddUsings(
            "System",
            "System.Collections.Generic",
            "Fantasy.Assembly",
            "Fantasy.DataStructure.Collection",
            "System.Runtime.CompilerServices"
        );
        builder.AppendLine();

        // Generated code always lives in Fantasy.Generated.
        builder.BeginNamespace("Fantasy.Generated");

        // Class implementing INetworkProtocolOpCodeResolver.
        builder.AddXmlComment($"Auto-generated NetworkProtocolOpCodeResolverRegistrar class for {assemblyName}");
        builder.BeginClass("NetworkProtocolOpCodeResolverRegistrar", "internal sealed", "INetworkProtocolOpCodeResolver");

        builder.AddXmlComment("GetOpCodeCount");
        builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
        builder.BeginMethod("public int GetOpCodeCount()");
        builder.AppendLine($"return {ToInvariant(networkProtocolTypeInfoList.Count)};");
        builder.EndMethod();

        builder.AddXmlComment("GetCustomRouteTypeCount");
        builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
        builder.BeginMethod("public int GetCustomRouteTypeCount()");
        builder.AppendLine($"return {ToInvariant(routeTypeInfos.Count)};");
        builder.EndMethod();

        builder.AddXmlComment("GetOpCodeType");
        builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
        builder.BeginMethod("public Type GetOpCodeType(uint opCode)");
        AppendOpCodeSwitch(
            builder,
            networkProtocolTypeInfoList,
            static info => $"typeof({info.FullName})",
            "null!");
        builder.EndMethod();

        builder.AddXmlComment("GetCustomRouteType");
        builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
        builder.BeginMethod("public int? GetCustomRouteType(uint opCode)");
        AppendOpCodeSwitch(
            builder,
            routeTypeInfos,
            static info => ToInvariant(info.RouteType.GetValueOrDefault()),
            "null");
        builder.EndMethod();
        builder.AppendLine();

        builder.EndClass();
        builder.EndNamespace();

        context.AddSource("NetworkProtocolOpCodeResolverRegistrar.g.cs", builder.ToString());
    }

    /// <summary>
    /// Emits <c>NetworkProtocolResponseTypeResolverRegistrar.g.cs</c>, which maps the OpCode of
    /// every protocol that declares a <c>ResponseType</c> property to that response type.
    /// </summary>
    private static void GenerateNetworkProtocolResponseTypesResolverCode(
        SourceProductionContext context,
        string assemblyName,
        List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
    {
        var requestList = networkProtocolTypeInfoList.Where(static d => d.ResponseType != null).ToList();

        var builder = new SourceCodeBuilder();
        builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
        builder.AddUsings(
            "System",
            "System.Collections.Generic",
            "Fantasy.Assembly",
            "Fantasy.DataStructure.Collection",
            "System.Runtime.CompilerServices"
        );
        builder.AppendLine();

        // Generated code always lives in Fantasy.Generated.
        builder.BeginNamespace("Fantasy.Generated");

        // Class implementing INetworkProtocolResponseTypeResolver.
        builder.AddXmlComment($"Auto-generated NetworkProtocolResponseTypeResolverRegistrar class for {assemblyName}");
        builder.BeginClass("NetworkProtocolResponseTypeResolverRegistrar", "internal sealed", "INetworkProtocolResponseTypeResolver");

        builder.AddXmlComment("GetRequestCount");
        builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
        builder.BeginMethod("public int GetRequestCount()");
        builder.AppendLine($"return {ToInvariant(requestList.Count)};");
        builder.EndMethod();

        builder.AddXmlComment("GetResponseType");
        builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
        builder.BeginMethod("public Type GetResponseType(uint opCode)");
        AppendOpCodeSwitch(
            builder,
            requestList,
            static info => $"typeof({info.ResponseType})",
            "null!");
        builder.EndMethod();
        builder.AppendLine();

        builder.EndClass();
        builder.EndNamespace();

        context.AddSource("NetworkProtocolResponseTypeResolverRegistrar.g.cs", builder.ToString());
    }

    /// <summary>
    /// Appends a method body that switches on the generated method's <c>opCode</c> parameter.
    /// </summary>
    /// <param name="builder">Builder positioned inside a method body.</param>
    /// <param name="infos">Protocols that get a <c>case</c>; when empty, only <c>return default</c> is emitted.</param>
    /// <param name="returnExpression">Produces the C# expression returned for a protocol's case.</param>
    /// <param name="defaultExpression">C# expression returned for unknown OpCodes.</param>
    private static void AppendOpCodeSwitch(
        SourceCodeBuilder builder,
        IReadOnlyCollection<NetworkProtocolTypeInfo> infos,
        Func<NetworkProtocolTypeInfo, string> returnExpression,
        string defaultExpression)
    {
        if (infos.Count == 0)
        {
            builder.AppendLine($"return {defaultExpression};");
            return;
        }

        builder.AppendLine("switch (opCode)");
        builder.AppendLine("{");
        builder.Indent();

        foreach (var info in infos)
        {
            builder.AppendLine($"case {ToInvariant(info.OpCode)}:");
            builder.AppendLine("{");
            builder.Indent();
            builder.AppendLine($"return {returnExpression(info)};");
            builder.Unindent();
            builder.AppendLine("}");
        }

        builder.AppendLine("default:");
        builder.AppendLine("{");
        builder.Indent();
        builder.AppendLine($"return {defaultExpression};");
        builder.Unindent();
        builder.AppendLine("}");

        builder.Unindent();
        builder.AppendLine("}");
        builder.AppendLine();
    }

    // Numbers are written into C# source, so they must never depend on the build machine's
    // culture (some cultures render the minus sign as U+2212, which is not valid C#).
    private static string ToInvariant(int value) => value.ToString(CultureInfo.InvariantCulture);

    private static string ToInvariant(uint value) => value.ToString(CultureInfo.InvariantCulture);

    #endregion

    /// <summary>
    /// Builds a <see cref="NetworkProtocolTypeInfo"/> for a class declaration, or returns null when
    /// the class is not an instantiable direct subclass of
    /// <c>Fantasy.Network.Interface.AMessage</c> or its <c>OpCode()</c> value is not a compile-time constant.
    /// </summary>
    private static NetworkProtocolTypeInfo? GetNetworkProtocolInfo(GeneratorSyntaxContext context)
    {
        var classDecl = (ClassDeclarationSyntax)context.Node;

        if (context.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol symbol || !symbol.IsInstantiable())
        {
            return null;
        }

        // Only direct subclasses of AMessage are protocols.
        var baseType = symbol.BaseType;
        if (baseType == null || baseType.ToDisplayString() != "Fantasy.Network.Interface.AMessage")
        {
            return null;
        }

        // OpCode(): either `=> X` or the first `return X;` in a block body.
        var opCodeMethod = symbol.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();
        var opCodeSyntax = GetDeclarationSyntax<MethodDeclarationSyntax>(opCodeMethod);
        var opCodeExpression = opCodeSyntax?.ExpressionBody?.Expression ?? GetFirstReturnExpression(opCodeSyntax?.Body);

        if (!TryGetUInt32(GetConstantValue(context, opCodeExpression), out var opCode))
        {
            return null;
        }

        // ResponseType: only the declared property type matters (e.g. G2C_TestResponse).
        var responseTypeName = symbol.GetMembers("ResponseType")
            .OfType<IPropertySymbol>()
            .FirstOrDefault()?
            .Type
            .GetFullName();

        // RouteType: recorded only when its getter returns an int compile-time constant.
        int? routeType = null;
        var routeTypeProperty = symbol.GetMembers("RouteType").OfType<IPropertySymbol>().FirstOrDefault();
        var routeTypeExpression = GetGetterExpression(GetDeclarationSyntax<PropertyDeclarationSyntax>(routeTypeProperty));

        if (GetConstantValue(context, routeTypeExpression) is int routeTypeValue)
        {
            routeType = routeTypeValue;
        }

        return new NetworkProtocolTypeInfo(
            symbol.GetFullName(),
            opCode,
            responseTypeName,
            routeType
        );
    }

    /// <summary>Returns the first declaring syntax node of <paramref name="symbol"/> if it is a <typeparamref name="TSyntax"/>.</summary>
    private static TSyntax? GetDeclarationSyntax<TSyntax>(ISymbol? symbol) where TSyntax : SyntaxNode
        => symbol?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as TSyntax;

    /// <summary>Returns the expression of the first <c>return</c> statement inside <paramref name="body"/>, if any.</summary>
    private static ExpressionSyntax? GetFirstReturnExpression(BlockSyntax? body)
        => body?.DescendantNodes().OfType<ReturnStatementSyntax>().FirstOrDefault()?.Expression;

    /// <summary>
    /// Returns the value expression of a property: its <c>=&gt; X</c> body, or the getter's
    /// <c>get =&gt; X</c> / first <c>return X;</c>.
    /// </summary>
    private static ExpressionSyntax? GetGetterExpression(PropertyDeclarationSyntax? property)
    {
        if (property == null)
        {
            return null;
        }

        if (property.ExpressionBody != null)
        {
            return property.ExpressionBody.Expression;
        }

        var getter = property.AccessorList?.Accessors
            .FirstOrDefault(static accessor => accessor.IsKind(SyntaxKind.GetAccessorDeclaration));

        return getter?.ExpressionBody?.Expression ?? GetFirstReturnExpression(getter?.Body);
    }

    /// <summary>
    /// Evaluates <paramref name="expression"/> as a compile-time constant, or returns null.
    /// </summary>
    /// <remarks>
    /// The expression may belong to another partial declaration in a different file. A semantic
    /// model only accepts nodes from its own syntax tree, so the matching model is fetched in that case.
    /// </remarks>
    private static object? GetConstantValue(GeneratorSyntaxContext context, ExpressionSyntax? expression)
    {
        if (expression == null)
        {
            return null;
        }

        var semanticModel = expression.SyntaxTree == context.SemanticModel.SyntaxTree
            ? context.SemanticModel
            : context.SemanticModel.Compilation.GetSemanticModel(expression.SyntaxTree);

        var constant = semanticModel.GetConstantValue(expression);
        return constant.HasValue ? constant.Value : null;
    }

    /// <summary>
    /// Converts a constant value to <see cref="uint"/>. Accepts <c>uint</c> and non-negative
    /// smaller integral constants; a literal such as <c>return 5;</c> yields an <c>int</c> constant.
    /// </summary>
    private static bool TryGetUInt32(object? value, out uint result)
    {
        switch (value)
        {
            case uint uintValue:
                result = uintValue;
                return true;
            case int intValue when intValue >= 0:
                result = (uint)intValue;
                return true;
            case ushort ushortValue:
                result = ushortValue;
                return true;
            case byte byteValue:
                result = byteValue;
                return true;
            default:
                result = 0;
                return false;
        }
    }

    /// <summary>
    /// Cheap syntactic pre-filter: a class whose base list mentions "AMessage" anywhere.
    /// The exact base-type check happens semantically in <see cref="GetNetworkProtocolInfo"/>.
    /// </summary>
    private static bool IsNetworkProtocolClass(SyntaxNode node)
        => node is ClassDeclarationSyntax { BaseList: { } baseList }
           && baseList.Types.Any(static baseTypeSyntax => baseTypeSyntax.Type.ToString().Contains("AMessage"));

    /// <summary>
    /// Description of one protocol class.
    /// </summary>
    /// <param name="FullName">Type name as emitted into <c>typeof(...)</c>.</param>
    /// <param name="OpCode">Constant returned by the class's <c>OpCode()</c> method.</param>
    /// <param name="ResponseType">Type name of the <c>ResponseType</c> property, if the class declares one.</param>
    /// <param name="RouteType">Constant returned by the <c>RouteType</c> property, if statically known.</param>
    private sealed record NetworkProtocolTypeInfo(
        string FullName,
        uint OpCode,
        string? ResponseType,
        int? RouteType);
}