379 lines
16 KiB
C#
379 lines
16 KiB
C#
using Fantasy.SourceGenerator.Common;
|
||
using Microsoft.CodeAnalysis;
|
||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||
|
||
namespace Fantasy.SourceGenerator.Generators
{
    /// <summary>
    /// Incremental source generator that finds Fantasy message handler classes and emits
    /// <c>Fantasy.Generated.MessageHandlerResolverRegistrar</c>. Handler classes are those whose base
    /// class is <c>Message&lt;T&gt;</c>, <c>MessageRPC&lt;,&gt;</c>, <c>Route&lt;,&gt;</c>,
    /// <c>RouteRPC&lt;,,&gt;</c>, <c>Addressable&lt;,&gt;</c>, <c>AddressableRPC&lt;,,&gt;</c>,
    /// <c>Roaming&lt;,&gt;</c> or <c>RoamingRPC&lt;,,&gt;</c>. The generated class routes a protocol
    /// code to the matching handler through a <c>switch</c>.
    /// </summary>
    [Generator]
    public sealed class MessageHandlerGenerator : IIncrementalGenerator
    {
        // Substrings looked for in the written base-type list by the cheap syntactic pre-filter.
        // The semantic check in GetMessageTypeInfo makes the real decision.
        private static readonly string[] HandlerBaseTypeMarkers =
        {
            "IMessageHandler",
            "IRouteMessageHandler",
            "Message<",
            "MessageRPC<",
            "Route<",
            "RouteRPC<",
            "Addressable<",
            "AddressableRPC<",
            "Roaming<",
            "RoamingRPC<"
        };

        // Opcode container classes searched, in this order, in the message type's namespace.
        private static readonly string[] OpCodeTypeNames = { "OuterOpcode", "InnerOpcode" };

        /// <summary>
        /// Wires the incremental pipeline: discover handler classes, then emit the registrar source
        /// when the project is a Fantasy project.
        /// </summary>
        /// <param name="context">Roslyn initialization context.</param>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Find all classes deriving from one of the message handler base types.
            var messageHandlerInfos = context.SyntaxProvider
                .CreateSyntaxProvider(
                    predicate: static (node, _) => IsMessageHandlerClass(node),
                    transform: static (ctx, _) => GetMessageTypeInfo(ctx))
                .Where(static info => info is not null)
                .Select(static (info, _) => info!)
                .Collect();

            // Reduce the compilation to the few values the output step needs. Roslyn creates a new
            // Compilation object for every edit, so combining with CompilationProvider directly would
            // re-run the output step each time. A value-equatable record lets the pipeline reuse the
            // previous output when nothing relevant changed.
            var settings = context.CompilationProvider.Select(static (compilation, _) => new GeneratorSettings(
                // Check 1: the FANTASY_NET or FANTASY_UNITY preprocessor symbol must be defined.
                // Check 2: the project must reference the Fantasy framework's core types.
                Enabled: CompilationHelper.HasFantasyDefine(compilation) &&
                         compilation.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") != null,
                AssemblyName: compilation.AssemblyName ?? "Unknown"));

            context.RegisterSourceOutput(settings.Combine(messageHandlerInfos), static (spc, source) =>
            {
                var (generatorSettings, handlers) = source;

                if (!generatorSettings.Enabled)
                {
                    return;
                }

                GenerateRegistrationCode(spc, generatorSettings.AssemblyName, handlers);
            });
        }

        /// <summary>
        /// Builds <c>MessageHandlerResolverRegistrar.g.cs</c> for the discovered handlers and adds it
        /// to the compilation.
        /// </summary>
        /// <param name="context">Output context that receives the generated source.</param>
        /// <param name="assemblyName">Assembly name, used in the generated XML comment.</param>
        /// <param name="messageHandlerInfos">Handlers produced by <see cref="GetMessageTypeInfo"/>.</param>
        private static void GenerateRegistrationCode(
            SourceProductionContext context,
            string assemblyName,
            IEnumerable<MessageHandlerInfo> messageHandlerInfos)
        {
            var messageHandlers = new List<MessageHandlerInfo>();
            var routeMessageHandlers = new List<MessageHandlerInfo>();
            var unresolvedHandlers = new List<MessageHandlerInfo>();

            // Distinct(): a partial handler class that repeats its base type in several parts is
            // discovered once per part. Without de-duplication it would yield duplicate fields and
            // duplicate case labels in the generated class.
            foreach (var messageHandlerInfo in messageHandlerInfos.Distinct())
            {
                // A handler whose opcode could not be resolved would otherwise produce "case :",
                // which does not compile. Keep it out of the dispatch tables and report it instead.
                if (messageHandlerInfo.OpCode is null)
                {
                    unresolvedHandlers.Add(messageHandlerInfo);
                    continue;
                }

                switch (messageHandlerInfo.HandlerType)
                {
                    case HandlerType.MessageHandler:
                        messageHandlers.Add(messageHandlerInfo);
                        break;
                    case HandlerType.RouteMessageHandler:
                        routeMessageHandlers.Add(messageHandlerInfo);
                        break;
                }
            }

            var builder = new SourceCodeBuilder();
            builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
            builder.AddUsings(
                "System",
                "System.Collections.Generic",
                "Fantasy.Assembly",
                "Fantasy.DataStructure.Dictionary",
                "Fantasy.Network.Interface",
                "Fantasy.Network",
                "Fantasy.Entitas",
                "Fantasy.Async",
                "System.Runtime.CompilerServices"
            );
            builder.AppendLine();
            builder.BeginNamespace("Fantasy.Generated");
            builder.AddXmlComment($"Auto-generated message handler registration class for {assemblyName}");
            builder.BeginClass("MessageHandlerResolverRegistrar", "internal sealed", "IMessageHandlerResolver");

            // Surface skipped handlers as compiler warnings in the consuming project.
            GenerateUnresolvedWarnings(builder, unresolvedHandlers);

            // Fields that hold each registered handler's Handle delegate.
            GenerateFields(builder, messageHandlers, routeMessageHandlers);

            // Count and dispatch methods.
            GenerateMethods(builder, messageHandlers, routeMessageHandlers);

            builder.EndClass();
            builder.EndNamespace();

            context.AddSource("MessageHandlerResolverRegistrar.g.cs", builder.ToString());
        }

        /// <summary>
        /// Emits one <c>#warning</c> directive per handler whose opcode could not be resolved, so
        /// the omission shows up in the consuming project's build output.
        /// </summary>
        private static void GenerateUnresolvedWarnings(SourceCodeBuilder builder, List<MessageHandlerInfo> unresolvedHandlers)
        {
            foreach (var handler in unresolvedHandlers)
            {
                builder.AppendLine($"#warning Fantasy: could not resolve the opcode of the message handled by {handler.TypeFullName}; this handler was not registered.", false);
            }
        }

        // Generated field names carry the handler's position in its list. Two handler classes with the
        // same simple name (in different namespaces or outer types) therefore no longer collide.
        private static string MessageFieldName(MessageHandlerInfo info, int index) => $"message_{info.TypeName}_{index}";

        private static string RouteMessageFieldName(MessageHandlerInfo info, int index) => $"routeMessage_{info.TypeName}_{index}";

        /// <summary>
        /// Emits one readonly delegate field per handler, initialised from a new handler instance's
        /// <c>Handle</c> method.
        /// </summary>
        private static void GenerateFields(SourceCodeBuilder builder, List<MessageHandlerInfo> messageHandlers, List<MessageHandlerInfo> routeMessageHandlers)
        {
            for (var i = 0; i < messageHandlers.Count; i++)
            {
                var info = messageHandlers[i];
                builder.AppendLine($"private readonly Func<Session, uint, uint, object, FTask> {MessageFieldName(info, i)} = new {info.TypeFullName}().Handle;");
            }

            for (var i = 0; i < routeMessageHandlers.Count; i++)
            {
                var info = routeMessageHandlers[i];
                builder.AppendLine($"private readonly Func<Session, Entity, uint, object, FTask> {RouteMessageFieldName(info, i)} = new {info.TypeFullName}().Handle;");
            }

            builder.AppendLine();
        }

        /// <summary>
        /// Emits the handler-count getters and the two dispatch methods. <c>MessageHandler</c> always
        /// exists; <c>RouteMessageHandler</c> is emitted only under <c>FANTASY_NET</c>. Counts reflect
        /// only the handlers that were actually registered.
        /// </summary>
        private static void GenerateMethods(SourceCodeBuilder builder, List<MessageHandlerInfo> messageHandlers, List<MessageHandlerInfo> routeMessageHandlers)
        {
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public int GetMessageHandlerCount()");
            builder.AppendLine($"return {messageHandlers.Count};");
            builder.EndMethod();

            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public int GetRouteMessageHandlerCount()");
            builder.AppendLine($"return {routeMessageHandlers.Count};");
            builder.EndMethod();

            // Plain messages: the handler is started fire-and-forget via Coroutine().
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public bool MessageHandler(Session session, uint rpcId, uint protocolCode, object message)");
            if (messageHandlers.Count > 0)
            {
                builder.AppendLine("switch (protocolCode)");
                builder.AppendLine("{");
                builder.Indent();
                for (var i = 0; i < messageHandlers.Count; i++)
                {
                    var info = messageHandlers[i];
                    AppendCase(builder, $"case {info.OpCode}:",
                        $"{MessageFieldName(info, i)}(session, rpcId, protocolCode, message).Coroutine();",
                        "return true;");
                }
                AppendCase(builder, "default:", "return false;");
                builder.Unindent();
                builder.AppendLine("}");
            }
            else
            {
                builder.AppendLine("return false;");
            }
            builder.EndMethod();

            // Route messages: the handler is awaited; the method only exists on the server build.
            builder.AppendLine("#if FANTASY_NET", false);
            builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
            builder.BeginMethod("public async FTask<bool> RouteMessageHandler(Session session, Entity entity, uint rpcId, uint protocolCode, object message)");
            if (routeMessageHandlers.Count > 0)
            {
                builder.AppendLine("switch (protocolCode)");
                builder.AppendLine("{");
                builder.Indent();
                for (var i = 0; i < routeMessageHandlers.Count; i++)
                {
                    var info = routeMessageHandlers[i];
                    AppendCase(builder, $"case {info.OpCode}:",
                        $"await {RouteMessageFieldName(info, i)}(session, entity, rpcId, message);",
                        "return true;");
                }
                AppendCase(builder, "default:", "await FTask.CompletedTask;", "return false;");
                builder.Unindent();
                builder.AppendLine("}");
            }
            else
            {
                builder.AppendLine("await FTask.CompletedTask;");
                builder.AppendLine("return false;");
            }
            builder.EndMethod();
            builder.AppendLine("#endif", false);
        }

        /// <summary>
        /// Emits a switch label followed by a braced, indented block of statements.
        /// </summary>
        private static void AppendCase(SourceCodeBuilder builder, string label, params string[] statements)
        {
            builder.AppendLine(label);
            builder.AppendLine("{");
            builder.Indent();
            foreach (var statement in statements)
            {
                builder.AppendLine(statement);
            }
            builder.Unindent();
            builder.AppendLine("}");
        }

        /// <summary>
        /// Cheap syntactic pre-filter. Returns true for a class declaration whose written base list
        /// contains one of <see cref="HandlerBaseTypeMarkers"/>.
        /// </summary>
        private static bool IsMessageHandlerClass(SyntaxNode node)
        {
            if (node is not ClassDeclarationSyntax { BaseList: { } baseList })
            {
                return false;
            }

            foreach (var baseType in baseList.Types)
            {
                var typeName = baseType.Type.ToString();

                foreach (var marker in HandlerBaseTypeMarkers)
                {
                    if (typeName.Contains(marker))
                    {
                        return true;
                    }
                }
            }

            return false;
        }

        /// <summary>
        /// Semantic transform: classifies the class by its generic base class and resolves the
        /// opcode of the handled message type.
        /// </summary>
        /// <returns>
        /// The handler description, or null when the class is not instantiable or its base class is
        /// not a recognised handler base type.
        /// </returns>
        private static MessageHandlerInfo? GetMessageTypeInfo(GeneratorSyntaxContext context)
        {
            var classDecl = (ClassDeclarationSyntax)context.Node;

            if (context.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol symbol ||
                !symbol.IsInstantiable())
            {
                return null;
            }

            var baseType = symbol.BaseType;

            if (baseType is not { IsGenericType: true } || baseType.TypeArguments.Length <= 0)
            {
                return null;
            }

            var (handlerType, messageArgumentIndex) = ClassifyBaseType(baseType.OriginalDefinition.ToDisplayString());

            if (handlerType == HandlerType.None)
            {
                return null;
            }

            return new MessageHandlerInfo(
                handlerType,
                symbol.GetFullName(),
                symbol.Name,
                GetOpCode(context, baseType, messageArgumentIndex));
        }

        /// <summary>
        /// Maps a handler base type's original definition to its handler kind, plus the index of the
        /// type argument that names the handled message. Message handlers use index 0 (the message);
        /// route-style handlers use index 1 (argument 0 is the entity).
        /// </summary>
        private static (HandlerType HandlerType, int MessageArgumentIndex) ClassifyBaseType(string baseTypeName) => baseTypeName switch
        {
            "Fantasy.Network.Interface.Message<T>" or
            "Fantasy.Network.Interface.MessageRPC<TRequest, TResponse>"
                => (HandlerType.MessageHandler, 0),

            "Fantasy.Network.Interface.Route<TEntity, TMessage>" or
            "Fantasy.Network.Interface.RouteRPC<TEntity, TRouteRequest, TRouteResponse>" or
            "Fantasy.Network.Interface.Addressable<TEntity, TMessage>" or
            "Fantasy.Network.Interface.AddressableRPC<TEntity, TRouteRequest, TRouteResponse>" or
            "Fantasy.Network.Interface.Roaming<TEntity, TMessage>" or
            "Fantasy.Network.Interface.RoamingRPC<TEntity, TRouteRequest, TRouteResponse>"
                => (HandlerType.RouteMessageHandler, 1),

            _ => (HandlerType.None, -1)
        };

        /// <summary>
        /// Resolves the opcode of the message type found at <c>baseType.TypeArguments[index]</c>.
        /// Strategy 1 looks for a const <c>uint</c> field named after the message type in
        /// <c>OuterOpcode</c> or <c>InnerOpcode</c>, in the message's namespace within its own
        /// assembly. Strategy 2 evaluates the expression returned by the message's <c>OpCode()</c>
        /// method, when that syntax belongs to the current compilation.
        /// </summary>
        /// <returns>The opcode, or null when it cannot be determined.</returns>
        private static uint? GetOpCode(GeneratorSyntaxContext context, INamedTypeSymbol baseType, int index)
        {
            if (index < 0 || baseType.TypeArguments.Length <= index)
            {
                return null;
            }

            // The argument may be a type parameter rather than a named type. A hard cast here threw
            // and aborted the whole generator.
            if (baseType.TypeArguments[index] is not INamedTypeSymbol messageType)
            {
                return null;
            }

            // Strategy 1: search the message type's assembly for the opcode container classes.
            var messageAssembly = messageType.ContainingAssembly;
            var messageNamespace = messageType.ContainingNamespace;

            // Unresolved (error) types may lack a containing assembly or namespace; skip strategy 1.
            if (messageAssembly != null && messageNamespace != null)
            {
                var namespaceName = messageNamespace.ToDisplayString();

                foreach (var opCodeTypeName in OpCodeTypeNames)
                {
                    var opCodeType = FindTypeInAssembly(messageAssembly.GlobalNamespace, namespaceName, opCodeTypeName);
                    var opCodeField = opCodeType?.GetMembers(messageType.Name).OfType<IFieldSymbol>().FirstOrDefault();

                    if (opCodeField is { IsConst: true, ConstantValue: uint constValue })
                    {
                        return constValue;
                    }
                }
            }

            // Strategy 2: fall back to the OpCode() method's syntax (only works for messages
            // declared in this compilation).
            return GetOpCodeFromOpCodeMethod(context.SemanticModel.Compilation, messageType);
        }

        /// <summary>
        /// Reads the opcode from the expression returned by <paramref name="messageType"/>'s
        /// <c>OpCode()</c> method. Both block bodies (<c>{ return X; }</c>) and expression bodies
        /// (<c>=&gt; X;</c>) are supported.
        /// </summary>
        private static uint? GetOpCodeFromOpCodeMethod(Compilation compilation, INamedTypeSymbol messageType)
        {
            var opCodeMethod = messageType.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();

            if (opCodeMethod?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() is not MethodDeclarationSyntax opCodeSyntax)
            {
                return null;
            }

            var returnedExpression = opCodeSyntax.ExpressionBody?.Expression
                ?? opCodeSyntax.Body?.DescendantNodes().OfType<ReturnStatementSyntax>().FirstOrDefault()?.Expression;

            if (returnedExpression is null)
            {
                return null;
            }

            var syntaxTree = opCodeSyntax.SyntaxTree;

            if (!compilation.ContainsSyntaxTree(syntaxTree))
            {
                return null;
            }

            var semanticModel = compilation.GetSemanticModel(syntaxTree);

            // Try symbol resolution: the expression names a const field.
            if (semanticModel.GetSymbolInfo(returnedExpression).Symbol is IFieldSymbol { IsConst: true, ConstantValue: uint fieldValue })
            {
                return fieldValue;
            }

            // Try constant folding of the expression itself.
            var constantValue = semanticModel.GetConstantValue(returnedExpression);
            return constantValue.HasValue && constantValue.Value is uint uintValue ? uintValue : (uint?)null;
        }

        /// <summary>
        /// Finds the type named <paramref name="typeName"/> directly inside the namespace whose display
        /// string equals <paramref name="targetNamespace"/>, starting from
        /// <paramref name="namespaceSymbol"/>. Only namespaces on the path to the target are visited.
        /// A nested namespace's display string always extends its parent's with ".", so any other
        /// branch cannot contain the target.
        /// </summary>
        private static INamedTypeSymbol? FindTypeInAssembly(INamespaceSymbol namespaceSymbol, string targetNamespace, string typeName)
        {
            var displayName = namespaceSymbol.ToDisplayString();

            // Exactly one namespace can match, and no descendant can, so the search ends here.
            if (displayName == targetNamespace)
            {
                return namespaceSymbol.GetTypeMembers(typeName).FirstOrDefault();
            }

            // Prune branches that cannot lead to the target (the global namespace is always walked).
            if (!namespaceSymbol.IsGlobalNamespace &&
                !targetNamespace.StartsWith(displayName + ".", System.StringComparison.Ordinal))
            {
                return null;
            }

            foreach (var childNamespace in namespaceSymbol.GetNamespaceMembers())
            {
                var result = FindTypeInAssembly(childNamespace, targetNamespace, typeName);

                if (result != null)
                {
                    return result;
                }
            }

            return null;
        }

        /// <summary>Kind of dispatch table a handler belongs to.</summary>
        private enum HandlerType
        {
            /// <summary>Not a recognised handler base type.</summary>
            None,

            /// <summary>Message / MessageRPC handlers, dispatched by <c>MessageHandler</c>.</summary>
            MessageHandler,

            /// <summary>Route, Addressable and Roaming handlers, dispatched by <c>RouteMessageHandler</c>.</summary>
            RouteMessageHandler
        }

        /// <summary>
        /// Value-equatable description of one discovered handler. Record equality lets the incremental
        /// pipeline cache transform results and de-duplicate repeated discoveries.
        /// </summary>
        /// <param name="HandlerType">Dispatch table the handler belongs to.</param>
        /// <param name="TypeFullName">Name used in the generated <c>new ...()</c> expression.</param>
        /// <param name="TypeName">Simple class name, used to build generated field names.</param>
        /// <param name="OpCode">Resolved protocol code, or null when it could not be determined.</param>
        private sealed record MessageHandlerInfo(
            HandlerType HandlerType,
            string TypeFullName,
            string TypeName,
            uint? OpCode);

        /// <summary>Compilation-derived inputs of the output step, kept small and value-equatable.</summary>
        /// <param name="Enabled">True when the Fantasy define and core types are present.</param>
        /// <param name="AssemblyName">Assembly name, or "Unknown" when the compilation has none.</param>
        private sealed record GeneratorSettings(bool Enabled, string AssemblyName);
    }
}