框架更新
This commit is contained in:
@@ -0,0 +1,379 @@
|
||||
using Fantasy.SourceGenerator.Common;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
|
||||
namespace Fantasy.SourceGenerator.Generators
|
||||
{
|
||||
[Generator]
|
||||
public sealed class MessageHandlerGenerator : IIncrementalGenerator
|
||||
{
|
||||
public void Initialize(IncrementalGeneratorInitializationContext context)
|
||||
{
|
||||
// 查找所有实现了消息相关接口的类
|
||||
var messageTypes = context.SyntaxProvider
|
||||
.CreateSyntaxProvider(
|
||||
predicate: static (node, _) => IsMessageHandlerClass(node),
|
||||
transform: static (ctx, _) => GetMessageTypeInfo(ctx))
|
||||
.Where(static info => info != null)
|
||||
.Collect();
|
||||
// 组合编译信息和找到的类型
|
||||
var compilationAndTypes = context.CompilationProvider.Combine(messageTypes);
|
||||
// 注册源代码输出
|
||||
context.RegisterSourceOutput(compilationAndTypes, static (spc, source) =>
|
||||
{
|
||||
// 检查1: 是否定义了 FANTASY_NET 或 FANTASY_UNITY 预编译符号
|
||||
if (!CompilationHelper.HasFantasyDefine(source.Left))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// 检查2: 是否引用了 Fantasy 框架的核心类型
|
||||
if (source.Left.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") == null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
GenerateRegistrationCode(spc, source.Left, source.Right!);
|
||||
});
|
||||
}
|
||||
|
||||
private static void GenerateRegistrationCode(
|
||||
SourceProductionContext context,
|
||||
Compilation compilation,
|
||||
IEnumerable<MessageHandlerInfo> messageHandlerInfos)
|
||||
{
|
||||
var messageHandlers = new List<MessageHandlerInfo>();
|
||||
var routeMessageHandlers = new List<MessageHandlerInfo>();
|
||||
|
||||
foreach (var messageHandlerInfo in messageHandlerInfos)
|
||||
{
|
||||
switch (messageHandlerInfo.HandlerType)
|
||||
{
|
||||
case HandlerType.MessageHandler:
|
||||
{
|
||||
messageHandlers.Add(messageHandlerInfo);
|
||||
break;
|
||||
}
|
||||
case HandlerType.RouteMessageHandler:
|
||||
{
|
||||
routeMessageHandlers.Add(messageHandlerInfo);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var assemblyName = compilation.AssemblyName ?? "Unknown";
|
||||
var builder = new SourceCodeBuilder();
|
||||
builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);
|
||||
builder.AddUsings(
|
||||
"System",
|
||||
"System.Collections.Generic",
|
||||
"Fantasy.Assembly",
|
||||
"Fantasy.DataStructure.Dictionary",
|
||||
"Fantasy.Network.Interface",
|
||||
"Fantasy.Network",
|
||||
"Fantasy.Entitas",
|
||||
"Fantasy.Async",
|
||||
"System.Runtime.CompilerServices"
|
||||
);
|
||||
builder.AppendLine();
|
||||
builder.BeginNamespace("Fantasy.Generated");
|
||||
builder.AddXmlComment($"Auto-generated message handler registration class for {assemblyName}");
|
||||
builder.BeginClass("MessageHandlerResolverRegistrar", "internal sealed", "IMessageHandlerResolver");
|
||||
// 生成字段用于存储已注册的实例(用于 UnRegister)
|
||||
GenerateFields(builder, messageHandlers, routeMessageHandlers);
|
||||
// 生成 Register 方法
|
||||
GenerateRegistrationCode(builder, messageHandlers, routeMessageHandlers);
|
||||
// 结束类和命名空间
|
||||
builder.EndClass();
|
||||
builder.EndNamespace();
|
||||
// 输出源代码
|
||||
context.AddSource("MessageHandlerResolverRegistrar.g.cs", builder.ToString());
|
||||
}
|
||||
|
||||
private static void GenerateFields(SourceCodeBuilder builder, List<MessageHandlerInfo> messageHandlers, List<MessageHandlerInfo> routeMessageHandlers)
|
||||
{
|
||||
foreach (var messageHandlerInfo in messageHandlers)
|
||||
{
|
||||
builder.AppendLine($"private Func<Session, uint, uint, object, FTask> message_{messageHandlerInfo.TypeName} = new {messageHandlerInfo.TypeFullName}().Handle;");
|
||||
}
|
||||
|
||||
foreach (var messageHandlerInfo in routeMessageHandlers)
|
||||
{
|
||||
builder.AppendLine($"private Func<Session, Entity, uint, object, FTask> routeMessage_{messageHandlerInfo.TypeName} = new {messageHandlerInfo.TypeFullName}().Handle;");
|
||||
}
|
||||
|
||||
builder.AppendLine();
|
||||
}
|
||||
|
||||
private static void GenerateRegistrationCode(SourceCodeBuilder builder, List<MessageHandlerInfo> messageHandlers, List<MessageHandlerInfo> routeMessageHandlers)
|
||||
{
|
||||
builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
|
||||
builder.BeginMethod("public int GetMessageHandlerCount()");
|
||||
builder.AppendLine($"return {messageHandlers.Count};");
|
||||
builder.EndMethod();
|
||||
builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
|
||||
builder.BeginMethod("public int GetRouteMessageHandlerCount()");
|
||||
builder.AppendLine($"return {routeMessageHandlers.Count};");
|
||||
builder.EndMethod();
|
||||
builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
|
||||
builder.BeginMethod("public bool MessageHandler(Session session, uint rpcId, uint protocolCode, object message)");
|
||||
if (messageHandlers.Any())
|
||||
{
|
||||
builder.AppendLine("switch (protocolCode)");
|
||||
builder.AppendLine("{");
|
||||
builder.Indent();
|
||||
foreach (var messageHandlerInfo in messageHandlers)
|
||||
{
|
||||
builder.AppendLine($"case {messageHandlerInfo.OpCode}:");
|
||||
builder.AppendLine("{");
|
||||
builder.Indent();
|
||||
builder.AppendLine($"message_{messageHandlerInfo.TypeName}(session, rpcId, protocolCode, message).Coroutine();");
|
||||
builder.AppendLine($"return true;");
|
||||
builder.Unindent();
|
||||
builder.AppendLine("}");
|
||||
}
|
||||
builder.AppendLine("default:");
|
||||
builder.AppendLine("{");
|
||||
builder.Indent();
|
||||
builder.AppendLine($"return false;");
|
||||
builder.Unindent();
|
||||
builder.AppendLine("}");
|
||||
builder.Unindent();
|
||||
builder.AppendLine("}");
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.AppendLine($"return false;");
|
||||
}
|
||||
builder.EndMethod();
|
||||
|
||||
builder.AppendLine("#if FANTASY_NET", false);
|
||||
builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
|
||||
builder.BeginMethod("public async FTask<bool> RouteMessageHandler(Session session, Entity entity, uint rpcId, uint protocolCode, object message)");
|
||||
if (routeMessageHandlers.Any())
|
||||
{
|
||||
builder.AppendLine("switch (protocolCode)");
|
||||
builder.AppendLine("{");
|
||||
builder.Indent();
|
||||
foreach (var routeMessageHandler in routeMessageHandlers)
|
||||
{
|
||||
builder.AppendLine($"case {routeMessageHandler.OpCode}:");
|
||||
builder.AppendLine("{");
|
||||
builder.Indent();
|
||||
builder.AppendLine($"await routeMessage_{routeMessageHandler.TypeName}(session, entity, rpcId, message);");
|
||||
builder.AppendLine($"return true;");
|
||||
builder.Unindent();
|
||||
builder.AppendLine("}");
|
||||
}
|
||||
builder.AppendLine("default:");
|
||||
builder.AppendLine("{");
|
||||
builder.Indent();
|
||||
builder.AppendLine($"await FTask.CompletedTask;");
|
||||
builder.AppendLine($"return false;");
|
||||
builder.Unindent();
|
||||
builder.AppendLine("}");
|
||||
builder.Unindent();
|
||||
builder.AppendLine("}");
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.AppendLine($"await FTask.CompletedTask;");
|
||||
builder.AppendLine($"return false;");
|
||||
}
|
||||
builder.EndMethod();
|
||||
builder.AppendLine("#endif", false);
|
||||
}
|
||||
|
||||
private static bool IsMessageHandlerClass(SyntaxNode node)
|
||||
{
|
||||
if (node is not ClassDeclarationSyntax classDecl)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (classDecl.BaseList == null || !classDecl.BaseList.Types.Any())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
foreach (var baseType in classDecl.BaseList.Types)
|
||||
{
|
||||
var typeName = baseType.Type.ToString();
|
||||
|
||||
if (typeName.Contains("IMessageHandler") ||
|
||||
typeName.Contains("IRouteMessageHandler") ||
|
||||
typeName.Contains("Message<") ||
|
||||
typeName.Contains("MessageRPC<") ||
|
||||
typeName.Contains("Route<") ||
|
||||
typeName.Contains("RouteRPC<") ||
|
||||
typeName.Contains("Addressable<") ||
|
||||
typeName.Contains("AddressableRPC<") ||
|
||||
typeName.Contains("Roaming<") ||
|
||||
typeName.Contains("RoamingRPC<"))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private static MessageHandlerInfo? GetMessageTypeInfo(GeneratorSyntaxContext context)
|
||||
{
|
||||
var classDecl = (ClassDeclarationSyntax)context.Node;
|
||||
|
||||
if (context.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol symbol ||
|
||||
!symbol.IsInstantiable())
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var baseType = symbol.BaseType;
|
||||
|
||||
if (baseType is not { IsGenericType: true } || baseType.TypeArguments.Length <= 0)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var baseTypeName = baseType.OriginalDefinition.ToDisplayString();
|
||||
|
||||
switch (baseTypeName)
|
||||
{
|
||||
case "Fantasy.Network.Interface.Message<T>":
|
||||
case "Fantasy.Network.Interface.MessageRPC<TRequest, TResponse>":
|
||||
{
|
||||
return new MessageHandlerInfo(
|
||||
HandlerType.MessageHandler,
|
||||
symbol.GetFullName(),
|
||||
symbol.Name,
|
||||
GetOpCode(context, baseType, 0));
|
||||
}
|
||||
case "Fantasy.Network.Interface.Route<TEntity, TMessage>":
|
||||
case "Fantasy.Network.Interface.RouteRPC<TEntity, TRouteRequest, TRouteResponse>":
|
||||
case "Fantasy.Network.Interface.Addressable<TEntity, TMessage>":
|
||||
case "Fantasy.Network.Interface.AddressableRPC<TEntity, TRouteRequest, TRouteResponse>":
|
||||
case "Fantasy.Network.Interface.Roaming<TEntity, TMessage>":
|
||||
case "Fantasy.Network.Interface.RoamingRPC<TEntity, TRouteRequest, TRouteResponse>":
|
||||
{
|
||||
return new MessageHandlerInfo(
|
||||
HandlerType.RouteMessageHandler,
|
||||
symbol.GetFullName(),
|
||||
symbol.Name,
|
||||
GetOpCode(context, baseType, 1));
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private static uint? GetOpCode(GeneratorSyntaxContext context, INamedTypeSymbol baseType, int index)
|
||||
{
|
||||
if (baseType.TypeArguments.Length <= index)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var messageType = (INamedTypeSymbol)baseType.TypeArguments[index];
|
||||
var messageName = messageType.Name;
|
||||
var compilation = context.SemanticModel.Compilation;
|
||||
|
||||
// 策略1:从消息类型所在程序集中搜索 OpCode 类
|
||||
var messageAssembly = messageType.ContainingAssembly;
|
||||
var namespaceName = messageType.ContainingNamespace.ToDisplayString();
|
||||
|
||||
// 遍历程序集中的所有类型,查找 OuterOpcode 或 InnerOpcode
|
||||
var opCodeTypeNames = new[] { "OuterOpcode", "InnerOpcode" };
|
||||
foreach (var opCodeTypeName in opCodeTypeNames)
|
||||
{
|
||||
var opCodeType = FindTypeInAssembly(messageAssembly.GlobalNamespace, namespaceName, opCodeTypeName);
|
||||
if (opCodeType != null)
|
||||
{
|
||||
var opCodeField = opCodeType.GetMembers(messageName).OfType<IFieldSymbol>().FirstOrDefault();
|
||||
if (opCodeField != null && opCodeField.IsConst && opCodeField.ConstantValue is uint constValue)
|
||||
{
|
||||
return constValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 策略2:如果策略1失败,尝试从 OpCode() 方法的语法树中解析(仅适用于同项目中的消息)
|
||||
var opCodeMethod = messageType.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();
|
||||
if (opCodeMethod != null)
|
||||
{
|
||||
var opCodeSyntax = opCodeMethod.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as MethodDeclarationSyntax;
|
||||
if (opCodeSyntax?.Body != null)
|
||||
{
|
||||
var returnStatement = opCodeSyntax.Body.DescendantNodes()
|
||||
.OfType<ReturnStatementSyntax>()
|
||||
.FirstOrDefault();
|
||||
|
||||
if (returnStatement?.Expression != null)
|
||||
{
|
||||
var syntaxTree = opCodeSyntax.SyntaxTree;
|
||||
|
||||
if (compilation.ContainsSyntaxTree(syntaxTree))
|
||||
{
|
||||
var semanticModel = compilation.GetSemanticModel(syntaxTree);
|
||||
|
||||
// 尝试符号解析
|
||||
var symbolInfo = semanticModel.GetSymbolInfo(returnStatement.Expression);
|
||||
if (symbolInfo.Symbol is IFieldSymbol fieldSymbol && fieldSymbol.IsConst && fieldSymbol.ConstantValue is uint constValue2)
|
||||
{
|
||||
return constValue2;
|
||||
}
|
||||
|
||||
// 尝试常量值解析
|
||||
var constantValue = semanticModel.GetConstantValue(returnStatement.Expression);
|
||||
if (constantValue.HasValue && constantValue.Value is uint uintValue)
|
||||
{
|
||||
return uintValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
// 辅助方法:在程序集的命名空间中递归查找指定类型
|
||||
private static INamedTypeSymbol? FindTypeInAssembly(INamespaceSymbol namespaceSymbol, string targetNamespace, string typeName)
|
||||
{
|
||||
// 如果当前命名空间匹配目标命名空间,查找类型
|
||||
if (namespaceSymbol.ToDisplayString() == targetNamespace)
|
||||
{
|
||||
var type = namespaceSymbol.GetTypeMembers(typeName).FirstOrDefault();
|
||||
if (type != null)
|
||||
{
|
||||
return type;
|
||||
}
|
||||
}
|
||||
|
||||
// 递归搜索子命名空间
|
||||
foreach (var childNamespace in namespaceSymbol.GetNamespaceMembers())
|
||||
{
|
||||
var result = FindTypeInAssembly(childNamespace, targetNamespace, typeName);
|
||||
if (result != null)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private enum HandlerType
|
||||
{
|
||||
None,
|
||||
MessageHandler,
|
||||
RouteMessageHandler
|
||||
}
|
||||
|
||||
private sealed record MessageHandlerInfo(
|
||||
HandlerType HandlerType,
|
||||
string TypeFullName,
|
||||
string TypeName,
|
||||
uint? OpCode);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user