398 lines
15 KiB
C#
398 lines
15 KiB
C#
using System.Linq;
|
||
using System.Text;
|
||
using Fantasy.SourceGenerator.Common;
|
||
using Microsoft.CodeAnalysis;
|
||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||
#pragma warning disable CS0649 // Field is never assigned to, and will always have its default value
|
||
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
|
||
|
||
namespace Fantasy.SourceGenerator.Generators;
|
||
[Generator]
|
||
internal partial class NetworkProtocolGenerator : IIncrementalGenerator
|
||
{
|
||
/// <summary>
/// Builds the incremental pipeline: selects candidate class declarations, extracts protocol
/// information from them, pairs the collected result with the compilation and registers
/// the callback that emits the generated sources.
/// </summary>
/// <param name="context">Initialization context handed to the generator by the compiler.</param>
public void Initialize(IncrementalGeneratorInitializationContext context)
{
    // Cheap syntactic filter first, semantic extraction second. Candidates for which the
    // extraction returns null are dropped before everything is gathered into one array.
    var protocolInfos = context.SyntaxProvider.CreateSyntaxProvider(
        predicate: static (syntaxNode, _) => IsNetworkProtocolClass(syntaxNode),
        transform: static (syntaxContext, _) => GetNetworkProtocolInfo(syntaxContext));
    var collectedInfos = protocolInfos.Where(static candidate => candidate is not null).Collect();
    var pipeline = context.CompilationProvider.Combine(collectedInfos);

    context.RegisterSourceOutput(pipeline, static (productionContext, input) =>
    {
        var compilation = input.Left;

        // Nothing is generated unless CompilationHelper.HasFantasyDefine accepts the compilation
        // and the registrar interface the generated classes implement can be resolved.
        var canGenerate = CompilationHelper.HasFantasyDefine(compilation)
            && compilation.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") != null;

        if (canGenerate)
        {
            GenerateCode(productionContext, compilation, input.Right!);
        }
    });
}
|
||
|
||
#region GenerateCode
|
||
|
||
/// <summary>
/// Emits the three generated files for the current compilation: the protocol type registrar,
/// the OpCode resolver and the response type resolver.
/// </summary>
/// <param name="context">Context used to add the generated sources.</param>
/// <param name="compilation">Current compilation; only its assembly name is read here.</param>
/// <param name="networkProtocolTypeInfos">Protocol descriptions collected by the pipeline.</param>
private static void GenerateCode(
    SourceProductionContext context,
    Compilation compilation,
    IEnumerable<NetworkProtocolTypeInfo> networkProtocolTypeInfos)
{
    // A partial class whose declarations each list AMessage passes the syntax predicate once per
    // declaration and therefore shows up several times with identical data. The record compares
    // by value, so Distinct() folds those repeats (keeping first-seen order); otherwise the
    // generated switch statements would contain duplicate case labels and the counts would be
    // too high.
    var networkProtocolTypeInfoList = networkProtocolTypeInfos.Distinct().ToList();

    // Assembly name of the current compilation; it is only used inside generated comments.
    var assemblyName = compilation.AssemblyName ?? "Unknown";

    // Class that registers the network protocol types.
    GenerateNetworkProtocolTypesCode(context, assemblyName, networkProtocolTypeInfoList);

    // OpCode helper methods.
    GenerateNetworkProtocolOpCodeResolverCode(context, assemblyName, networkProtocolTypeInfoList);

    // ResponseType helper methods for request messages.
    GenerateNetworkProtocolResponseTypesResolverCode(context, assemblyName, networkProtocolTypeInfoList);
}
|
||
|
||
/// <summary>
/// Writes <c>NetworkProtocolRegistrar.g.cs</c>: a class implementing
/// <c>INetworkProtocolRegistrar</c> whose <c>GetNetworkProtocolTypes</c> method returns a list
/// with one <c>typeof</c> entry per discovered protocol type.
/// </summary>
/// <param name="context">Context used to add the generated source.</param>
/// <param name="assemblyName">Assembly name, used only in the generated XML comment.</param>
/// <param name="networkProtocolTypeInfoList">Protocol types to register.</param>
private static void GenerateNetworkProtocolTypesCode(
    SourceProductionContext context,
    string assemblyName,
    List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
{
    var source = new SourceCodeBuilder();

    // File header, then the using directives of the generated file.
    source.AppendLine(GeneratorConstants.AutoGeneratedHeader);
    source.AddUsings(
        "System",
        "System.Collections.Generic",
        "Fantasy.Assembly",
        "Fantasy.DataStructure.Collection");

    // The generated code always lives in the Fantasy.Generated namespace.
    source.BeginNamespace("Fantasy.Generated");

    // Class definition (implements INetworkProtocolRegistrar).
    source.AddXmlComment($"Auto-generated NetworkProtocol registration class for {assemblyName}");
    source.BeginClass("NetworkProtocolRegistrar", "internal sealed", "INetworkProtocolRegistrar");
    source.BeginMethod("public List<Type> GetNetworkProtocolTypes()");

    var typeCount = networkProtocolTypeInfoList.Count;

    if (typeCount == 0)
    {
        source.AppendLine("return new List<Type>();");
    }
    else
    {
        // Collection initializer pre-sized to the number of entries.
        source.AppendLine($"return new List<Type>({typeCount})");
        source.AppendLine("{");
        source.Indent();

        for (var index = 0; index < typeCount; index++)
        {
            source.AppendLine($"typeof({networkProtocolTypeInfoList[index].FullName}),");
        }

        source.Unindent();
        source.AppendLine("};");
        source.AppendLine();
    }

    source.EndMethod();
    source.AppendLine();

    // Close the class and the namespace, then hand the text to the compiler.
    source.EndClass();
    source.EndNamespace();
    context.AddSource("NetworkProtocolRegistrar.g.cs", source.ToString());
}
|
||
|
||
/// <summary>
/// Writes <c>NetworkProtocolOpCodeResolverRegistrar.g.cs</c>: a class implementing
/// <c>INetworkProtocolOpCodeResolver</c> with the number of protocol types, the number of
/// types that carry a route type, and two switch-based lookups from OpCode to message type
/// and from OpCode to route type.
/// </summary>
/// <param name="context">Context used to add the generated source.</param>
/// <param name="assemblyName">Assembly name, used only in the generated XML comment.</param>
/// <param name="networkProtocolTypeInfoList">All discovered protocol types.</param>
private static void GenerateNetworkProtocolOpCodeResolverCode(
    SourceProductionContext context,
    string assemblyName,
    List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
{
    const string inlineAttribute = "[MethodImpl(MethodImplOptions.AggressiveInlining)]";

    // Only entries for which a constant RouteType was found take part in the route lookup.
    var routeTypeInfos = networkProtocolTypeInfoList.Where(d => d.RouteType.HasValue).ToList();
    var builder = new SourceCodeBuilder();

    // File header.
    builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);

    // Using directives of the generated file.
    builder.AddUsings(
        "System",
        "System.Collections.Generic",
        "Fantasy.Assembly",
        "Fantasy.DataStructure.Collection",
        "System.Runtime.CompilerServices"
    );
    builder.AppendLine();

    // The generated code always lives in the Fantasy.Generated namespace.
    builder.BeginNamespace("Fantasy.Generated");

    // Class definition (implements INetworkProtocolOpCodeResolver).
    builder.AddXmlComment($"Auto-generated NetworkProtocolOpCodeResolverRegistrar class for {assemblyName}");
    builder.BeginClass("NetworkProtocolOpCodeResolverRegistrar", "internal sealed", "INetworkProtocolOpCodeResolver");

    builder.AddXmlComment("GetOpCodeCount");
    builder.AppendLine(inlineAttribute);
    builder.BeginMethod("public int GetOpCodeCount()");
    builder.AppendLine($"return {networkProtocolTypeInfoList.Count};");
    builder.EndMethod();

    builder.AddXmlComment("GetCustomRouteTypeCount");
    builder.AppendLine(inlineAttribute);
    builder.BeginMethod("public int GetCustomRouteTypeCount()");
    builder.AppendLine($"return {routeTypeInfos.Count};");
    builder.EndMethod();

    // GetOpCodeType: OpCode -> message type.
    builder.AddXmlComment("GetOpCodeType");
    builder.AppendLine(inlineAttribute);
    builder.BeginMethod("public Type GetOpCodeType(uint opCode)");

    if (networkProtocolTypeInfoList.Any())
    {
        builder.AppendLine("switch (opCode)");
        builder.AppendLine("{");
        builder.Indent();

        foreach (var networkProtocolTypeInfo in networkProtocolTypeInfoList)
        {
            AppendSwitchSection(
                builder,
                $"case {networkProtocolTypeInfo.OpCode}:",
                $"return typeof({networkProtocolTypeInfo.FullName});");
        }

        AppendSwitchSection(builder, "default:", "return null!;");
        builder.Unindent();
        builder.AppendLine("}");
        builder.AppendLine();
    }
    else
    {
        builder.AppendLine("return null!;");
    }

    builder.EndMethod();

    // GetCustomRouteType: OpCode -> route type.
    builder.AddXmlComment("GetCustomRouteType");
    builder.AppendLine(inlineAttribute);
    builder.BeginMethod("public int? GetCustomRouteType(uint opCode)");

    if (routeTypeInfos.Any())
    {
        builder.AppendLine("switch (opCode)");
        builder.AppendLine("{");
        builder.Indent();

        foreach (var routeTypeInfo in routeTypeInfos)
        {
            // The route type is a signed value written into C# source text. Formatting it with
            // the current culture can produce a culture-specific minus sign that is not valid
            // C#, so the invariant culture is used. GetValueOrDefault is safe here because the
            // list was filtered on HasValue above.
            var routeTypeLiteral = routeTypeInfo.RouteType.GetValueOrDefault()
                .ToString(global::System.Globalization.CultureInfo.InvariantCulture);

            AppendSwitchSection(builder, $"case {routeTypeInfo.OpCode}:", $"return {routeTypeLiteral};");
        }

        AppendSwitchSection(builder, "default:", "return null;");
        builder.Unindent();
        builder.AppendLine("}");
        builder.AppendLine();
    }
    else
    {
        builder.AppendLine("return null;");
    }

    builder.EndMethod();
    builder.AppendLine();

    // Close the class and the namespace.
    builder.EndClass();
    builder.EndNamespace();

    // Hand the generated text to the compiler.
    context.AddSource("NetworkProtocolOpCodeResolverRegistrar.g.cs", builder.ToString());

    // Emits one switch section: the label followed by a braced block holding one statement.
    static void AppendSwitchSection(SourceCodeBuilder target, string label, string statement)
    {
        target.AppendLine(label);
        target.AppendLine("{");
        target.Indent();
        target.AppendLine(statement);
        target.Unindent();
        target.AppendLine("}");
    }
}
|
||
|
||
/// <summary>
/// Writes <c>NetworkProtocolResponseTypeResolverRegistrar.g.cs</c>: a class implementing
/// <c>INetworkProtocolResponseTypeResolver</c> with the number of request messages and a
/// switch-based lookup from a request OpCode to its response type.
/// </summary>
/// <param name="context">Context used to add the generated source.</param>
/// <param name="assemblyName">Assembly name, used only in the generated XML comment.</param>
/// <param name="networkProtocolTypeInfoList">All discovered protocol types.</param>
private static void GenerateNetworkProtocolResponseTypesResolverCode(
    SourceProductionContext context,
    string assemblyName,
    List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
{
    const string inlineAttribute = "[MethodImpl(MethodImplOptions.AggressiveInlining)]";

    // Only entries that declare a ResponseType property are treated as requests.
    var requestList = networkProtocolTypeInfoList.Where(d => d.ResponseType != null).ToList();
    var builder = new SourceCodeBuilder();

    // File header.
    builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);

    // Using directives of the generated file.
    builder.AddUsings(
        "System",
        "System.Collections.Generic",
        "Fantasy.Assembly",
        "Fantasy.DataStructure.Collection",
        "System.Runtime.CompilerServices"
    );
    builder.AppendLine();

    // The generated code always lives in the Fantasy.Generated namespace.
    builder.BeginNamespace("Fantasy.Generated");

    // Class definition (implements INetworkProtocolResponseTypeResolver).
    builder.AddXmlComment($"Auto-generated NetworkProtocolResponseTypeResolverRegistrar class for {assemblyName}");
    builder.BeginClass("NetworkProtocolResponseTypeResolverRegistrar", "internal sealed", "INetworkProtocolResponseTypeResolver");

    // The generated XML comments name the member they sit on; they previously carried the
    // names of the OpCode resolver's members.
    builder.AddXmlComment("GetRequestCount");
    builder.AppendLine(inlineAttribute);
    builder.BeginMethod("public int GetRequestCount()");
    builder.AppendLine($"return {requestList.Count};");
    builder.EndMethod();

    // GetResponseType: request OpCode -> response type.
    builder.AddXmlComment("GetResponseType");
    builder.AppendLine(inlineAttribute);
    builder.BeginMethod("public Type GetResponseType(uint opCode)");

    if (requestList.Any())
    {
        builder.AppendLine("switch (opCode)");
        builder.AppendLine("{");
        builder.Indent();

        foreach (var request in requestList)
        {
            AppendSwitchSection(builder, $"case {request.OpCode}:", $"return typeof({request.ResponseType});");
        }

        AppendSwitchSection(builder, "default:", "return null!;");
        builder.Unindent();
        builder.AppendLine("}");
        builder.AppendLine();
    }
    else
    {
        builder.AppendLine("return null!;");
    }

    builder.EndMethod();
    builder.AppendLine();

    // Close the class and the namespace.
    builder.EndClass();
    builder.EndNamespace();

    // Hand the generated text to the compiler.
    context.AddSource("NetworkProtocolResponseTypeResolverRegistrar.g.cs", builder.ToString());

    // Emits one switch section: the label followed by a braced block holding one statement.
    static void AppendSwitchSection(SourceCodeBuilder target, string label, string statement)
    {
        target.AppendLine(label);
        target.AppendLine("{");
        target.Indent();
        target.AppendLine(statement);
        target.Unindent();
        target.AppendLine("}");
    }
}
|
||
|
||
#endregion
|
||
|
||
/// <summary>
/// Extracts the protocol description of a candidate class declaration.
/// </summary>
/// <param name="context">Syntax context whose node is a class declaration accepted by the predicate.</param>
/// <returns>
/// The description, or <c>null</c> when the class has no declared symbol, is rejected by
/// <c>IsInstantiable</c>, does not derive directly from
/// <c>Fantasy.Network.Interface.AMessage</c>, or has no <c>OpCode</c> method returning a
/// constant <c>uint</c>.
/// </returns>
private static NetworkProtocolTypeInfo? GetNetworkProtocolInfo(GeneratorSyntaxContext context)
{
    var classDecl = (ClassDeclarationSyntax)context.Node;

    if (context.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol symbol || !symbol.IsInstantiable())
    {
        return null;
    }

    // Only the direct base type is checked; a missing base type fails the comparison as well.
    if (symbol.BaseType?.ToDisplayString() != "Fantasy.Network.Interface.AMessage")
    {
        return null;
    }

    // OpCode: constant returned by the OpCode method. Both "uint OpCode() => X;" and
    // "uint OpCode() { return X; }" are understood; for a block body the first return
    // statement found is used.
    var opCodeMethod = symbol.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();
    var opCodeSyntax = opCodeMethod?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as MethodDeclarationSyntax;
    var opCodeExpression = opCodeSyntax?.ExpressionBody?.Expression
        ?? FirstReturnedExpression(opCodeSyntax?.Body);

    if (ReadConstant(context.SemanticModel, opCodeExpression) is not uint opCodeValue)
    {
        return null;
    }

    // ResponseType: full name of the ResponseType property's type (for example G2C_TestResponse),
    // or null when the class declares no such property.
    var responseTypeProperty = symbol.GetMembers("ResponseType").OfType<IPropertySymbol>().FirstOrDefault();
    string? responseTypeName = responseTypeProperty?.Type.GetFullName();

    // RouteType: constant produced by the RouteType property, read either from "=> X" or from
    // a get accessor ("get => X;" / "get { return X; }").
    int? routeTypeValue = null;
    var routeTypeProperty = symbol.GetMembers("RouteType").OfType<IPropertySymbol>().FirstOrDefault();
    var routeTypeSyntax = routeTypeProperty?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as PropertyDeclarationSyntax;
    var routeTypeExpression = routeTypeSyntax?.ExpressionBody?.Expression
        ?? GetterExpression(routeTypeSyntax);

    if (ReadConstant(context.SemanticModel, routeTypeExpression) is int intValue)
    {
        routeTypeValue = intValue;
    }

    return new NetworkProtocolTypeInfo(
        symbol.GetFullName(),
        opCodeValue,
        responseTypeName,
        routeTypeValue);

    // Expression of the first return statement inside a block, if there is one.
    static ExpressionSyntax? FirstReturnedExpression(BlockSyntax? block)
    {
        return block?.DescendantNodes().OfType<ReturnStatementSyntax>().FirstOrDefault()?.Expression;
    }

    // Expression produced by a property's get accessor, in arrow or block form.
    static ExpressionSyntax? GetterExpression(PropertyDeclarationSyntax? property)
    {
        var getter = property?.AccessorList?.Accessors
            .FirstOrDefault(static accessor => accessor.Keyword.ValueText == "get");

        return getter?.ExpressionBody?.Expression ?? FirstReturnedExpression(getter?.Body);
    }

    // Constant value of an expression, or null when there is no expression or no constant.
    static object? ReadConstant(SemanticModel model, ExpressionSyntax? expression)
    {
        if (expression == null)
        {
            return null;
        }

        // With partial classes the member can be declared in a different file than the class
        // declaration being processed. A semantic model only answers questions about its own
        // syntax tree, so the model of the expression's tree is used in that case.
        var owningModel = ReferenceEquals(expression.SyntaxTree, model.SyntaxTree)
            ? model
            : model.Compilation.GetSemanticModel(expression.SyntaxTree);
        var constant = owningModel.GetConstantValue(expression);

        return constant.HasValue ? constant.Value : null;
    }
}
|
||
|
||
/// <summary>
/// Syntax-only pre-filter: accepts class declarations that have a base list in which at least
/// one entry's text contains "AMessage". The semantic check happens later in the transform.
/// </summary>
/// <param name="node">Syntax node offered by the syntax provider.</param>
/// <returns><c>true</c> when the node is a candidate protocol class; otherwise <c>false</c>.</returns>
private static bool IsNetworkProtocolClass(SyntaxNode node)
{
    if (node is ClassDeclarationSyntax { BaseList: { } baseList })
    {
        var baseTypes = baseList.Types;

        for (var index = 0; index < baseTypes.Count; index++)
        {
            if (baseTypes[index].Type.ToString().Contains("AMessage"))
            {
                return true;
            }
        }
    }

    return false;
}
|
||
|
||
/// <summary>
/// Description of one discovered protocol type, carried through the generator pipeline.
/// </summary>
/// <param name="FullName">Full name of the protocol type, written into generated <c>typeof(...)</c> expressions.</param>
/// <param name="OpCode">Constant returned by the type's <c>OpCode</c> method.</param>
/// <param name="ResponseType">Full name of the type of the <c>ResponseType</c> property, or <c>null</c> when the type has no such property.</param>
/// <param name="RouteType">Constant value of the <c>RouteType</c> property, or <c>null</c> when none was found.</param>
private sealed record NetworkProtocolTypeInfo(string FullName, uint OpCode, string? ResponseType, int? RouteType);
|
||
} |