using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Fantasy.SourceGenerator.Common;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Fantasy.SourceGenerator.Generators
{
    /// <summary>
    /// Incremental source generator that emits <c>Fantasy.Generated.MessageHandlerResolverRegistrar</c>,
    /// an <c>IMessageHandlerResolver</c> implementation that dispatches an incoming protocol code to the
    /// matching handler class found in the compilation.
    /// </summary>
    [Generator]
    public sealed class MessageHandlerGenerator : IIncrementalGenerator
    {
        // Namespace declaring the Message/Route/Addressable/Roaming handler base types.
        private const string HandlerBaseTypeNamespace = "Fantasy.Network.Interface";

        // Cheap syntactic markers used by the predicate to pre-filter candidate classes
        // before the semantic check in GetMessageTypeInfo.
        private static readonly string[] HandlerBaseTypeMarkers =
        {
            "IMessageHandler",
            "IRouteMessageHandler",
            "Message<",
            "MessageRPC<",
            "Route<",
            "RouteRPC<",
            "Addressable<",
            "AddressableRPC<",
            "Roaming<",
            "RoamingRPC<"
        };

        // OpCode holder classes searched (in this order) by strategy 1 of GetOpCode.
        private static readonly string[] OpCodeTypeNames = { "OuterOpcode", "InnerOpcode" };

        // NOTE(review): the diagnostic IDs below were chosen for this generator; align them with
        // any project-wide ID registry / analyzer release tracking if one exists.
        private static readonly DiagnosticDescriptor UnresolvedOpCodeDescriptor = new DiagnosticDescriptor(
            id: "FANTASYMH001",
            title: "Unable to resolve message OpCode",
            messageFormat: "Could not resolve the OpCode of the message handled by '{0}'; the handler was not registered",
            category: "Fantasy.SourceGenerator",
            defaultSeverity: DiagnosticSeverity.Error,
            isEnabledByDefault: true);

        private static readonly DiagnosticDescriptor DuplicateOpCodeDescriptor = new DiagnosticDescriptor(
            id: "FANTASYMH002",
            title: "Duplicate message handler OpCode",
            messageFormat: "Handler '{0}' handles OpCode {1}, which is already handled by '{2}'; the handler was not registered",
            category: "Fantasy.SourceGenerator",
            defaultSeverity: DiagnosticSeverity.Error,
            isEnabledByDefault: true);

        /// <summary>
        /// Wires the pipeline: discover handler classes, pair them with the compilation and emit
        /// the registrar when the compilation is a Fantasy project.
        /// </summary>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Find every class that implements one of the message-related handler base types.
            var messageTypes = context.SyntaxProvider
                .CreateSyntaxProvider(
                    predicate: static (node, _) => IsMessageHandlerClass(node),
                    transform: static (ctx, cancellationToken) => GetMessageTypeInfo(ctx, cancellationToken))
                .Where(static info => info != null)
                .Collect();

            // Combine the compilation with the discovered types.
            var compilationAndTypes = context.CompilationProvider.Combine(messageTypes);

            // Register the source output.
            context.RegisterSourceOutput(compilationAndTypes, static (spc, source) =>
            {
                // Check 1: the FANTASY_NET or FANTASY_UNITY preprocessor symbol must be defined.
                if (!CompilationHelper.HasFantasyDefine(source.Left))
                {
                    return;
                }

                // Check 2: the compilation must reference the Fantasy framework's core types.
                if (source.Left.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") == null)
                {
                    return;
                }

                GenerateRegistrationCode(spc, source.Left, source.Right);
            });
        }

        /// <summary>
        /// Partitions the discovered handlers by kind, validates their OpCodes (reporting an error
        /// diagnostic for unresolved or duplicate ones) and writes <c>MessageHandlerResolverRegistrar.g.cs</c>.
        /// </summary>
        /// <param name="context">Output context used to add the source file and report diagnostics.</param>
        /// <param name="compilation">Compilation being generated for; only its assembly name is used here.</param>
        /// <param name="messageHandlerInfos">Handler descriptions produced by the pipeline; null entries are ignored.</param>
        private static void GenerateRegistrationCode(
            SourceProductionContext context,
            Compilation compilation,
            IEnumerable<MessageHandlerInfo?> messageHandlerInfos)
        {
            var messageHandlers = new List<MessageHandlerInfo>();
            var routeMessageHandlers = new List<MessageHandlerInfo>();
            var messageOpCodes = new Dictionary<uint, MessageHandlerInfo>();
            var routeMessageOpCodes = new Dictionary<uint, MessageHandlerInfo>();

            // A partial class whose parts each repeat the base list is discovered once per declaration;
            // keep only the first occurrence of every handler type.
            var seenTypes = new HashSet<string>();

            foreach (var messageHandlerInfo in messageHandlerInfos)
            {
                if (messageHandlerInfo == null || !seenTypes.Add(messageHandlerInfo.TypeFullName))
                {
                    continue;
                }

                List<MessageHandlerInfo> handlers;
                Dictionary<uint, MessageHandlerInfo> registeredOpCodes;

                switch (messageHandlerInfo.HandlerType)
                {
                    case HandlerType.MessageHandler:
                    {
                        handlers = messageHandlers;
                        registeredOpCodes = messageOpCodes;
                        break;
                    }
                    case HandlerType.RouteMessageHandler:
                    {
                        handlers = routeMessageHandlers;
                        registeredOpCodes = routeMessageOpCodes;
                        break;
                    }
                    default:
                    {
                        continue;
                    }
                }

                // An unresolved OpCode would otherwise be emitted as an empty `case :` label.
                if (messageHandlerInfo.OpCode is not uint opCode)
                {
                    context.ReportDiagnostic(Diagnostic.Create(
                        UnresolvedOpCodeDescriptor, Location.None, messageHandlerInfo.TypeFullName));
                    continue;
                }

                // Two handlers for one OpCode would otherwise produce duplicate `case` labels.
                if (registeredOpCodes.TryGetValue(opCode, out var existing))
                {
                    context.ReportDiagnostic(Diagnostic.Create(
                        DuplicateOpCodeDescriptor, Location.None, messageHandlerInfo.TypeFullName, opCode, existing.TypeFullName));
                    continue;
                }

                registeredOpCodes.Add(opCode, messageHandlerInfo);
                handlers.Add(messageHandlerInfo);
            }

            var assemblyName = compilation.AssemblyName ?? "Unknown";
            var builder = new SourceCodeBuilder();
            builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
            builder.AddUsings(
                "System",
                "System.Collections.Generic",
                "Fantasy.Assembly",
                "Fantasy.DataStructure.Dictionary",
                "Fantasy.Network.Interface",
                "Fantasy.Network",
                "Fantasy.Entitas",
                "Fantasy.Async",
                "System.Runtime.CompilerServices"
            );
            builder.AppendLine();
            builder.BeginNamespace("Fantasy.Generated");
            builder.AddXmlComment($"Auto-generated message handler registration class for {assemblyName}");
            builder.BeginClass("MessageHandlerResolverRegistrar", "internal sealed", "IMessageHandlerResolver");

            // Fields storing each handler instance's Handle delegate (the original note says they are kept for UnRegister).
            GenerateFields(builder, messageHandlers, routeMessageHandlers);
            // Count queries plus the dispatch methods.
            GenerateRegistrationCode(builder, messageHandlers, routeMessageHandlers);

            // Close the class and namespace.
            builder.EndClass();
            builder.EndNamespace();

            // Emit the source file.
            context.AddSource("MessageHandlerResolverRegistrar.g.cs", builder.ToString());
        }

        /// <summary>Emits one cached <c>Handle</c> delegate field per handler class.</summary>
        private static void GenerateFields(
            SourceCodeBuilder builder,
            List<MessageHandlerInfo> messageHandlers,
            List<MessageHandlerInfo> routeMessageHandlers)
        {
            for (var i = 0; i < messageHandlers.Count; i++)
            {
                var info = messageHandlers[i];
                builder.AppendLine($"private readonly Func<Session, uint, uint, object, FTask> {GetMessageFieldName(i, info)} = new {info.TypeFullName}().Handle;");
            }

            for (var i = 0; i < routeMessageHandlers.Count; i++)
            {
                var info = routeMessageHandlers[i];
                builder.AppendLine($"private readonly Func<Session, Entity, uint, object, FTask> {GetRouteMessageFieldName(i, info)} = new {info.TypeFullName}().Handle;");
            }

            builder.AppendLine();
        }

        /// <summary>
        /// Emits the handler-count methods and the two dispatch methods, which switch on
        /// <c>protocolCode</c> and invoke the matching cached delegate.
        /// </summary>
        private static void GenerateRegistrationCode(
            SourceCodeBuilder builder,
            List<MessageHandlerInfo> messageHandlers,
            List<MessageHandlerInfo> routeMessageHandlers)
        {
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public int GetMessageHandlerCount()");
            builder.AppendLine($"return {messageHandlers.Count};");
            builder.EndMethod();

            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public int GetRouteMessageHandlerCount()");
            builder.AppendLine($"return {routeMessageHandlers.Count};");
            builder.EndMethod();

            // Plain messages: fire the handler (Coroutine) and report whether the code was handled.
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public bool MessageHandler(Session session, uint rpcId, uint protocolCode, object message)");

            if (messageHandlers.Count > 0)
            {
                builder.AppendLine("switch (protocolCode)");
                builder.AppendLine("{");
                builder.Indent();

                for (var i = 0; i < messageHandlers.Count; i++)
                {
                    var info = messageHandlers[i];
                    AppendCaseBlock(
                        builder,
                        $"case {info.OpCode}:",
                        $"{GetMessageFieldName(i, info)}(session, rpcId, protocolCode, message).Coroutine();",
                        "return true;");
                }

                AppendCaseBlock(builder, "default:", "return false;");
                builder.Unindent();
                builder.AppendLine("}");
            }
            else
            {
                builder.AppendLine("return false;");
            }

            builder.EndMethod();

            // Route-style dispatch is only emitted when FANTASY_NET is defined.
            builder.AppendLine("#if FANTASY_NET", false);
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public async FTask<bool> RouteMessageHandler(Session session, Entity entity, uint rpcId, uint protocolCode, object message)");

            if (routeMessageHandlers.Count > 0)
            {
                builder.AppendLine("switch (protocolCode)");
                builder.AppendLine("{");
                builder.Indent();

                for (var i = 0; i < routeMessageHandlers.Count; i++)
                {
                    var info = routeMessageHandlers[i];
                    AppendCaseBlock(
                        builder,
                        $"case {info.OpCode}:",
                        $"await {GetRouteMessageFieldName(i, info)}(session, entity, rpcId, message);",
                        "return true;");
                }

                AppendCaseBlock(builder, "default:", "await FTask.CompletedTask;", "return false;");
                builder.Unindent();
                builder.AppendLine("}");
            }
            else
            {
                builder.AppendLine("await FTask.CompletedTask;");
                builder.AppendLine("return false;");
            }

            builder.EndMethod();
            builder.AppendLine("#endif", false);
        }

        /// <summary>Emits <paramref name="label"/> followed by a braced, indented block of statements.</summary>
        private static void AppendCaseBlock(SourceCodeBuilder builder, string label, params string[] statements)
        {
            builder.AppendLine(label);
            builder.AppendLine("{");
            builder.Indent();

            foreach (var statement in statements)
            {
                builder.AppendLine(statement);
            }

            builder.Unindent();
            builder.AppendLine("}");
        }

        // Field names carry the handler's list index so that two handler classes sharing a simple
        // name (in different namespaces) do not produce duplicate fields in the generated class.
        private static string GetMessageFieldName(int index, MessageHandlerInfo info) => $"message_{index}_{info.TypeName}";

        private static string GetRouteMessageFieldName(int index, MessageHandlerInfo info) => $"routeMessage_{index}_{info.TypeName}";

        /// <summary>
        /// Syntactic pre-filter: true for a class whose base list mentions one of the handler
        /// base types. The semantic check happens later in <see cref="GetMessageTypeInfo"/>.
        /// </summary>
        private static bool IsMessageHandlerClass(SyntaxNode node)
        {
            if (node is not ClassDeclarationSyntax { BaseList: { } baseList } || baseList.Types.Count == 0)
            {
                return false;
            }

            foreach (var baseType in baseList.Types)
            {
                var typeName = baseType.Type.ToString();

                foreach (var marker in HandlerBaseTypeMarkers)
                {
                    if (typeName.Contains(marker))
                    {
                        return true;
                    }
                }
            }

            return false;
        }

        /// <summary>
        /// Classifies the class at <c>context.Node</c> by its direct generic base type and resolves
        /// the OpCode of the message it handles.
        /// </summary>
        /// <returns>
        /// A <see cref="MessageHandlerInfo"/> for instantiable classes deriving directly from a
        /// recognized base type in <c>Fantasy.Network.Interface</c>; otherwise null.
        /// </returns>
        private static MessageHandlerInfo? GetMessageTypeInfo(GeneratorSyntaxContext context, CancellationToken cancellationToken = default)
        {
            var classDecl = (ClassDeclarationSyntax)context.Node;

            if (context.SemanticModel.GetDeclaredSymbol(classDecl, cancellationToken) is not INamedTypeSymbol symbol || !symbol.IsInstantiable())
            {
                return null;
            }

            var baseType = symbol.BaseType;

            if (baseType is not { IsGenericType: true } || baseType.TypeArguments.Length <= 0)
            {
                return null;
            }

            // Match on namespace + simple name rather than ToDisplayString(), which includes the
            // generic parameter list (e.g. "Message<TMessage>") and so depends on parameter names.
            var definition = baseType.OriginalDefinition;

            if (definition.ContainingNamespace?.ToDisplayString() != HandlerBaseTypeNamespace)
            {
                return null;
            }

            switch (definition.Name)
            {
                // The handled message is the first type argument.
                case "Message":
                case "MessageRPC":
                {
                    return new MessageHandlerInfo(
                        HandlerType.MessageHandler,
                        symbol.GetFullName(),
                        symbol.Name,
                        GetOpCode(context, baseType, 0, cancellationToken));
                }
                // The handled message is the second type argument.
                case "Route":
                case "RouteRPC":
                case "Addressable":
                case "AddressableRPC":
                case "Roaming":
                case "RoamingRPC":
                {
                    return new MessageHandlerInfo(
                        HandlerType.RouteMessageHandler,
                        symbol.GetFullName(),
                        symbol.Name,
                        GetOpCode(context, baseType, 1, cancellationToken));
                }
                default:
                {
                    return null;
                }
            }
        }

        /// <summary>
        /// Resolves the OpCode of the message type found at <paramref name="index"/> in
        /// <paramref name="baseType"/>'s type arguments.
        /// </summary>
        /// <returns>The constant OpCode, or null when it cannot be resolved.</returns>
        private static uint? GetOpCode(GeneratorSyntaxContext context, INamedTypeSymbol baseType, int index, CancellationToken cancellationToken = default)
        {
            if (baseType.TypeArguments.Length <= index)
            {
                return null;
            }

            // A type parameter or array here is not a concrete message type; a hard cast would throw.
            if (baseType.TypeArguments[index] is not INamedTypeSymbol messageType)
            {
                return null;
            }

            // Strategy 1: look for a const field named after the message on OuterOpcode/InnerOpcode
            // in the message's own namespace and assembly.
            if (messageType.ContainingAssembly is { } messageAssembly && messageType.ContainingNamespace is { } messageNamespace)
            {
                foreach (var opCodeTypeName in OpCodeTypeNames)
                {
                    var opCodeType = FindTypeInAssembly(messageAssembly, messageNamespace, opCodeTypeName);
                    var opCodeField = opCodeType?.GetMembers(messageType.Name).OfType<IFieldSymbol>().FirstOrDefault();

                    if (opCodeField is { IsConst: true, ConstantValue: uint constValue })
                    {
                        return constValue;
                    }
                }
            }

            // Strategy 2: parse the message's OpCode() method (only possible for messages in this compilation).
            return GetOpCodeFromMethodSyntax(context.SemanticModel.Compilation, messageType, cancellationToken);
        }

        /// <summary>
        /// Evaluates the first <c>return</c> expression of <paramref name="messageType"/>'s
        /// <c>OpCode()</c> method (block or expression body) as a <c>uint</c> constant.
        /// </summary>
        private static uint? GetOpCodeFromMethodSyntax(Compilation compilation, INamedTypeSymbol messageType, CancellationToken cancellationToken)
        {
            var opCodeMethod = messageType.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();

            if (opCodeMethod?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(cancellationToken) is not MethodDeclarationSyntax opCodeSyntax)
            {
                return null;
            }

            var returnExpression =
                opCodeSyntax.Body?.DescendantNodes().OfType<ReturnStatementSyntax>().FirstOrDefault()?.Expression
                ?? opCodeSyntax.ExpressionBody?.Expression;

            if (returnExpression == null)
            {
                return null;
            }

            var syntaxTree = opCodeSyntax.SyntaxTree;

            if (!compilation.ContainsSyntaxTree(syntaxTree))
            {
                return null;
            }

            var semanticModel = compilation.GetSemanticModel(syntaxTree);

            // Try symbol resolution: a reference to a const uint field.
            if (semanticModel.GetSymbolInfo(returnExpression, cancellationToken).Symbol is IFieldSymbol { IsConst: true, ConstantValue: uint fieldValue })
            {
                return fieldValue;
            }

            // Try constant-value evaluation: literals and other constant expressions.
            var constantValue = semanticModel.GetConstantValue(returnExpression, cancellationToken);

            if (constantValue is { HasValue: true, Value: uint literalValue })
            {
                return literalValue;
            }

            return null;
        }

        /// <summary>
        /// Finds a type named <paramref name="typeName"/> declared directly in the namespace that
        /// mirrors <paramref name="targetNamespace"/> within <paramref name="assembly"/>.
        /// Descends straight along the namespace path instead of scanning the whole namespace tree.
        /// </summary>
        private static INamedTypeSymbol? FindTypeInAssembly(IAssemblySymbol assembly, INamespaceSymbol targetNamespace, string typeName)
        {
            // Collect the namespace path innermost-first; popping the stack yields it outermost-first.
            var path = new Stack<string>();

            for (INamespaceSymbol? ns = targetNamespace; ns is { IsGlobalNamespace: false }; ns = ns.ContainingNamespace)
            {
                path.Push(ns.Name);
            }

            INamespaceSymbol? current = assembly.GlobalNamespace;

            while (current != null && path.Count > 0)
            {
                var name = path.Pop();
                current = current.GetNamespaceMembers().FirstOrDefault(child => child.Name == name);
            }

            return current?.GetTypeMembers(typeName).FirstOrDefault();
        }

        private enum HandlerType
        {
            None,
            MessageHandler,
            RouteMessageHandler
        }

        /// <summary>
        /// Pipeline output describing one handler class. It is a record so that results can be
        /// compared by value between incremental runs.
        /// </summary>
        private sealed record MessageHandlerInfo(
            HandlerType HandlerType,
            string TypeFullName,
            string TypeName,
            uint? OpCode);
    }
}