// Source: Fishing2Server/Fantasy/Fantasy.Net/Fantasy.SourceGenerator/Generators/MessageHandlerGenerator.cs
// Snapshot: 2025-10-29 17:59:43 +08:00 - 379 lines, 16 KiB, C# (viewer links: Raw, Blame, History)
// Viewer note: this file contains ambiguous Unicode characters, i.e. characters that might be
// confused with other characters. If that is intentional, the warning can be safely ignored.
using Fantasy.SourceGenerator.Common;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Fantasy.SourceGenerator.Generators
{
    /// <summary>
    /// Incremental source generator that discovers Fantasy message handler classes and emits
    /// <c>Fantasy.Generated.MessageHandlerResolverRegistrar</c>, an <c>IMessageHandlerResolver</c>
    /// that dispatches an incoming message to its handler by protocol code (opcode).
    /// </summary>
    /// <remarks>
    /// Nothing is generated unless the compilation passes <c>CompilationHelper.HasFantasyDefine</c>
    /// and references <c>Fantasy.Assembly.INetworkProtocolRegistrar</c>.
    /// </remarks>
    [Generator]
    public sealed class MessageHandlerGenerator : IIncrementalGenerator
    {
        // Substrings looked for in a class's base list by the cheap syntactic pre-filter.
        // GetMessageTypeInfo makes the authoritative semantic check afterwards.
        private static readonly string[] HandlerBaseTypeMarkers =
        {
            "IMessageHandler",
            "IRouteMessageHandler",
            "Message<",
            "MessageRPC<",
            "Route<",
            "RouteRPC<",
            "Addressable<",
            "AddressableRPC<",
            "Roaming<",
            "RoamingRPC<",
        };

        // Opcode holder classes searched (in this order) by strategy 1 of GetOpCode.
        private static readonly string[] OpCodeTypeNames = { "OuterOpcode", "InnerOpcode" };

        /// <summary>
        /// Wires the pipeline: syntactic pre-filter -> semantic classification -> one generated file.
        /// </summary>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Find every class that derives from one of the message handler base types.
            var messageTypes = context.SyntaxProvider
                .CreateSyntaxProvider(
                    predicate: static (node, _) => IsMessageHandlerClass(node),
                    transform: static (ctx, _) => GetMessageTypeInfo(ctx))
                .Where(static info => info is not null)
                .Select(static (info, _) => info!)
                .Collect();

            // Reduce the compilation to the few values the output step needs. Combining with the raw
            // CompilationProvider would re-run the output on every edit, because a Compilation never
            // compares equal to the previous one; this record does, so unchanged output stays cached.
            var settings = context.CompilationProvider.Select(static (compilation, _) =>
                new GeneratorSettings(
                    // Check 1: the FANTASY_NET or FANTASY_UNITY preprocessor symbol is defined.
                    // Check 2: the Fantasy framework's core types are referenced.
                    CompilationHelper.HasFantasyDefine(compilation) &&
                    compilation.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") != null,
                    compilation.AssemblyName ?? "Unknown"));

            context.RegisterSourceOutput(settings.Combine(messageTypes), static (spc, source) =>
            {
                if (!source.Left.IsEnabled)
                {
                    return;
                }

                GenerateRegistrationCode(spc, source.Left.AssemblyName, source.Right);
            });
        }

        /// <summary>
        /// Emits <c>MessageHandlerResolverRegistrar.g.cs</c> for the collected handlers.
        /// </summary>
        /// <param name="context">Context used to add the generated source file.</param>
        /// <param name="assemblyName">Name of the compiled assembly, used in the generated XML comment.</param>
        /// <param name="messageHandlerInfos">
        /// Handlers found by the pipeline. A partial class may appear once per declaration part whose
        /// base list matched the pre-filter, so the sequence is de-duplicated here.
        /// </param>
        private static void GenerateRegistrationCode(
            SourceProductionContext context,
            string assemblyName,
            IEnumerable<MessageHandlerInfo> messageHandlerInfos)
        {
            var handlers = messageHandlerInfos.Distinct().ToList();
            var errors = new List<string>();

            var messageHandlers = BuildEntries(
                handlers.Where(static h => h.HandlerType == HandlerType.MessageHandler), "message", errors);
            var routeMessageHandlers = BuildEntries(
                handlers.Where(static h => h.HandlerType == HandlerType.RouteMessageHandler), "routeMessage", errors);

            var builder = new SourceCodeBuilder();
            builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
            builder.AddUsings(
                "System",
                "System.Collections.Generic",
                "Fantasy.Assembly",
                "Fantasy.DataStructure.Dictionary",
                "Fantasy.Network.Interface",
                "Fantasy.Network",
                "Fantasy.Entitas",
                "Fantasy.Async",
                "System.Runtime.CompilerServices"
            );
            builder.AppendLine();

            // Report unusable handlers as compiler errors with a readable message, instead of emitting
            // code that fails with an unrelated error (e.g. "case :" or a duplicate case label).
            foreach (var error in errors)
            {
                builder.AppendLine($"#error {error}", false);
            }

            builder.BeginNamespace("Fantasy.Generated");
            builder.AddXmlComment($"Auto-generated message handler registration class for {assemblyName}");
            builder.BeginClass("MessageHandlerResolverRegistrar", "internal sealed", "IMessageHandlerResolver");

            // One delegate field per handler, each bound to a single handler instance.
            GenerateFields(builder, messageHandlers, routeMessageHandlers);

            // Count accessors plus the two dispatch methods.
            GenerateMethods(builder, messageHandlers, routeMessageHandlers);

            builder.EndClass();
            builder.EndNamespace();

            context.AddSource("MessageHandlerResolverRegistrar.g.cs", builder.ToString());
        }

        /// <summary>
        /// Turns handlers of one kind into dispatch entries with unique field names.
        /// </summary>
        /// <param name="infos">Handlers of a single <see cref="HandlerType"/>.</param>
        /// <param name="fieldPrefix">Prefix of the generated delegate field names.</param>
        /// <param name="errors">
        /// Receives one message per handler that was skipped, either because its opcode is unknown or
        /// because its opcode is already taken by an earlier handler.
        /// </param>
        /// <returns>Entries in input order; each opcode appears at most once.</returns>
        private static List<HandlerEntry> BuildEntries(
            IEnumerable<MessageHandlerInfo> infos,
            string fieldPrefix,
            List<string> errors)
        {
            var entries = new List<HandlerEntry>();
            var ownerByOpCode = new Dictionary<uint, string>();

            foreach (var info in infos)
            {
                if (info.OpCode is not { } opCode)
                {
                    errors.Add($"Fantasy: could not resolve the protocol code of the message handled by '{info.TypeFullName}'.");
                    continue;
                }

                if (ownerByOpCode.TryGetValue(opCode, out var owner))
                {
                    errors.Add($"Fantasy: '{info.TypeFullName}' and '{owner}' both handle protocol code {opCode}.");
                    continue;
                }

                ownerByOpCode.Add(opCode, info.TypeFullName);

                // The index suffix keeps field names unique when handlers in different namespaces
                // share the same simple class name.
                entries.Add(new HandlerEntry(info, $"{fieldPrefix}_{info.TypeName}_{entries.Count}", opCode));
            }

            return entries;
        }

        /// <summary>
        /// Emits one private delegate field per handler, initialised with a new handler instance's
        /// <c>Handle</c> method.
        /// </summary>
        private static void GenerateFields(
            SourceCodeBuilder builder,
            List<HandlerEntry> messageHandlers,
            List<HandlerEntry> routeMessageHandlers)
        {
            foreach (var entry in messageHandlers)
            {
                builder.AppendLine($"private readonly Func<Session, uint, uint, object, FTask> {entry.FieldName} = new {entry.Info.TypeFullName}().Handle;");
            }

            foreach (var entry in routeMessageHandlers)
            {
                builder.AppendLine($"private readonly Func<Session, Entity, uint, object, FTask> {entry.FieldName} = new {entry.Info.TypeFullName}().Handle;");
            }

            builder.AppendLine();
        }

        /// <summary>
        /// Emits the handler-count accessors and the <c>MessageHandler</c> / <c>RouteMessageHandler</c>
        /// dispatch methods. Each method switches on the protocol code and returns whether a handler
        /// was found. <c>RouteMessageHandler</c> is only compiled under <c>FANTASY_NET</c>.
        /// </summary>
        private static void GenerateMethods(
            SourceCodeBuilder builder,
            List<HandlerEntry> messageHandlers,
            List<HandlerEntry> routeMessageHandlers)
        {
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public int GetMessageHandlerCount()");
            builder.AppendLine($"return {messageHandlers.Count};");
            builder.EndMethod();

            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public int GetRouteMessageHandlerCount()");
            builder.AppendLine($"return {routeMessageHandlers.Count};");
            builder.EndMethod();

            // Plain messages: the handler task is started fire-and-forget via Coroutine().
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public bool MessageHandler(Session session, uint rpcId, uint protocolCode, object message)");

            if (messageHandlers.Count > 0)
            {
                builder.AppendLine("switch (protocolCode)");
                builder.AppendLine("{");
                builder.Indent();

                foreach (var entry in messageHandlers)
                {
                    builder.AppendLine($"case {entry.OpCode}:");
                    builder.AppendLine("{");
                    builder.Indent();
                    builder.AppendLine($"{entry.FieldName}(session, rpcId, protocolCode, message).Coroutine();");
                    builder.AppendLine("return true;");
                    builder.Unindent();
                    builder.AppendLine("}");
                }

                builder.AppendLine("default:");
                builder.AppendLine("{");
                builder.Indent();
                builder.AppendLine("return false;");
                builder.Unindent();
                builder.AppendLine("}");
                builder.Unindent();
                builder.AppendLine("}");
            }
            else
            {
                builder.AppendLine("return false;");
            }

            builder.EndMethod();

            // Route messages: the handler is awaited before reporting success.
            builder.AppendLine("#if FANTASY_NET", false);
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public async FTask<bool> RouteMessageHandler(Session session, Entity entity, uint rpcId, uint protocolCode, object message)");

            if (routeMessageHandlers.Count > 0)
            {
                builder.AppendLine("switch (protocolCode)");
                builder.AppendLine("{");
                builder.Indent();

                foreach (var entry in routeMessageHandlers)
                {
                    builder.AppendLine($"case {entry.OpCode}:");
                    builder.AppendLine("{");
                    builder.Indent();
                    builder.AppendLine($"await {entry.FieldName}(session, entity, rpcId, message);");
                    builder.AppendLine("return true;");
                    builder.Unindent();
                    builder.AppendLine("}");
                }

                builder.AppendLine("default:");
                builder.AppendLine("{");
                builder.Indent();
                builder.AppendLine("await FTask.CompletedTask;");
                builder.AppendLine("return false;");
                builder.Unindent();
                builder.AppendLine("}");
                builder.Unindent();
                builder.AppendLine("}");
            }
            else
            {
                builder.AppendLine("await FTask.CompletedTask;");
                builder.AppendLine("return false;");
            }

            builder.EndMethod();
            builder.AppendLine("#endif", false);
        }

        /// <summary>
        /// Cheap syntactic pre-filter: true for a class whose base list textually mentions one of the
        /// handler base types. False positives are discarded by <see cref="GetMessageTypeInfo"/>.
        /// </summary>
        private static bool IsMessageHandlerClass(SyntaxNode node)
        {
            if (node is not ClassDeclarationSyntax { BaseList: { } baseList })
            {
                return false;
            }

            foreach (var baseType in baseList.Types)
            {
                var typeName = baseType.Type.ToString();

                foreach (var marker in HandlerBaseTypeMarkers)
                {
                    if (typeName.Contains(marker))
                    {
                        return true;
                    }
                }
            }

            return false;
        }

        /// <summary>
        /// Semantic step of the pipeline: classifies a candidate class by its direct base type.
        /// </summary>
        /// <returns>
        /// Handler info (whose <c>OpCode</c> may be null if it could not be resolved), or <c>null</c>
        /// when the class is not instantiable or its direct base type is not a recognised handler base.
        /// </returns>
        private static MessageHandlerInfo? GetMessageTypeInfo(GeneratorSyntaxContext context)
        {
            var classDecl = (ClassDeclarationSyntax)context.Node;

            if (context.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol symbol ||
                !symbol.IsInstantiable())
            {
                return null;
            }

            // Only the direct base class is inspected; indirect inheritance is not recognised.
            var baseType = symbol.BaseType;
            if (baseType is not { IsGenericType: true } || baseType.TypeArguments.Length <= 0)
            {
                return null;
            }

            switch (baseType.OriginalDefinition.ToDisplayString())
            {
                // Plain message handlers: the message type is the first type argument.
                case "Fantasy.Network.Interface.Message<T>":
                case "Fantasy.Network.Interface.MessageRPC<TRequest, TResponse>":
                    return new MessageHandlerInfo(
                        HandlerType.MessageHandler,
                        symbol.GetFullName(),
                        symbol.Name,
                        GetOpCode(context, baseType, 0));

                // Route-style handlers: the first type argument is the entity, the second the message.
                case "Fantasy.Network.Interface.Route<TEntity, TMessage>":
                case "Fantasy.Network.Interface.RouteRPC<TEntity, TRouteRequest, TRouteResponse>":
                case "Fantasy.Network.Interface.Addressable<TEntity, TMessage>":
                case "Fantasy.Network.Interface.AddressableRPC<TEntity, TRouteRequest, TRouteResponse>":
                case "Fantasy.Network.Interface.Roaming<TEntity, TMessage>":
                case "Fantasy.Network.Interface.RoamingRPC<TEntity, TRouteRequest, TRouteResponse>":
                    return new MessageHandlerInfo(
                        HandlerType.RouteMessageHandler,
                        symbol.GetFullName(),
                        symbol.Name,
                        GetOpCode(context, baseType, 1));

                default:
                    return null;
            }
        }

        /// <summary>
        /// Resolves the opcode of the message type passed as type argument <paramref name="index"/>
        /// of <paramref name="baseType"/>.
        /// </summary>
        /// <returns>The opcode, or <c>null</c> if it could not be determined.</returns>
        private static uint? GetOpCode(GeneratorSyntaxContext context, INamedTypeSymbol baseType, int index)
        {
            // A type-parameter argument (open generic handler) is not an INamedTypeSymbol; return null
            // rather than throwing, which would abort the whole generator.
            if (baseType.TypeArguments.Length <= index ||
                baseType.TypeArguments[index] is not INamedTypeSymbol messageType)
            {
                return null;
            }

            return FindOpCodeConstant(messageType)
                ?? ParseOpCodeMethod(context.SemanticModel.Compilation, messageType);
        }

        // Strategy 1: a `const uint` field named after the message on an OuterOpcode/InnerOpcode class
        // declared in the message's own assembly and namespace. This also works for referenced assemblies.
        private static uint? FindOpCodeConstant(INamedTypeSymbol messageType)
        {
            // Roslyn documents ContainingAssembly as nullable (symbols not owned by a single assembly).
            var messageAssembly = messageType.ContainingAssembly;
            var namespaceName = messageType.ContainingNamespace?.ToDisplayString();
            if (messageAssembly is null || namespaceName is null)
            {
                return null;
            }

            foreach (var opCodeTypeName in OpCodeTypeNames)
            {
                var opCodeType = FindTypeInAssembly(messageAssembly.GlobalNamespace, namespaceName, opCodeTypeName);
                var opCodeField = opCodeType?.GetMembers(messageType.Name).OfType<IFieldSymbol>().FirstOrDefault();

                if (opCodeField is { IsConst: true, ConstantValue: uint value })
                {
                    return value;
                }
            }

            return null;
        }

        // Strategy 2 (fallback): evaluate the expression returned by the message's OpCode() method.
        // This needs its syntax tree, so it only works for messages declared in this compilation.
        // Both `{ return X; }` bodies and expression bodies (`=> X`) are supported.
        private static uint? ParseOpCodeMethod(Compilation compilation, INamedTypeSymbol messageType)
        {
            var opCodeMethod = messageType.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();
            if (opCodeMethod?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() is not MethodDeclarationSyntax opCodeSyntax)
            {
                return null;
            }

            var returnedExpression = opCodeSyntax.ExpressionBody?.Expression
                ?? opCodeSyntax.Body?.DescendantNodes().OfType<ReturnStatementSyntax>().FirstOrDefault()?.Expression;

            var syntaxTree = opCodeSyntax.SyntaxTree;
            if (returnedExpression == null || !compilation.ContainsSyntaxTree(syntaxTree))
            {
                return null;
            }

            var semanticModel = compilation.GetSemanticModel(syntaxTree);

            // Try symbol resolution first: `return OuterOpcode.Foo;`
            if (semanticModel.GetSymbolInfo(returnedExpression).Symbol is IFieldSymbol { IsConst: true, ConstantValue: uint fieldValue })
            {
                return fieldValue;
            }

            // Then fall back to constant folding: `return 0x1234;` or a constant expression.
            var constantValue = semanticModel.GetConstantValue(returnedExpression);
            if (constantValue.HasValue && constantValue.Value is uint literalValue)
            {
                return literalValue;
            }

            return null;
        }

        /// <summary>
        /// Finds type <paramref name="typeName"/> directly inside the namespace whose display string
        /// equals <paramref name="targetNamespace"/>, starting from <paramref name="namespaceSymbol"/>.
        /// </summary>
        private static INamedTypeSymbol? FindTypeInAssembly(INamespaceSymbol namespaceSymbol, string targetNamespace, string typeName)
        {
            if (namespaceSymbol.ToDisplayString() == targetNamespace)
            {
                // Child namespaces have longer display strings and can never match the target.
                return namespaceSymbol.GetTypeMembers(typeName).FirstOrDefault();
            }

            foreach (var childNamespace in namespaceSymbol.GetNamespaceMembers())
            {
                // Only descend along the path to the target (e.g. "A", then "A.B" for target "A.B.C").
                var childName = childNamespace.ToDisplayString();
                if (childName != targetNamespace &&
                    !targetNamespace.StartsWith(childName + ".", System.StringComparison.Ordinal))
                {
                    continue;
                }

                var result = FindTypeInAssembly(childNamespace, targetNamespace, typeName);
                if (result != null)
                {
                    return result;
                }
            }

            return null;
        }

        // Which dispatch method a handler is registered in.
        private enum HandlerType
        {
            None,
            MessageHandler,
            RouteMessageHandler
        }

        // Result of the semantic step. Only strings and primitives are stored, so records compare by
        // value and the incremental pipeline can cache them. OpCode is null when it could not be resolved.
        private sealed record MessageHandlerInfo(
            HandlerType HandlerType,
            string TypeFullName,
            string TypeName,
            uint? OpCode);

        // A validated handler ready for emission: unique generated field name plus resolved opcode.
        private sealed record HandlerEntry(
            MessageHandlerInfo Info,
            string FieldName,
            uint OpCode);

        // The parts of the Compilation the output step needs, reduced to an equatable value.
        private sealed record GeneratorSettings(
            bool IsEnabled,
            string AssemblyName);
    }
}