Files
Fishing2Server/Fantasy/Fantasy.Net/Fantasy.SourceGenerator/Generators/NetworkProtocolGenerator.cs
2025-10-29 17:59:43 +08:00

398 lines
15 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using System.Linq;
using System.Text;
using Fantasy.SourceGenerator.Common;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
#pragma warning disable CS0649 // Field is never assigned to, and will always have its default value
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
namespace Fantasy.SourceGenerator.Generators;
[Generator]
internal partial class NetworkProtocolGenerator : IIncrementalGenerator
{
/// <summary>
/// Builds the incremental pipeline: candidate class declarations are filtered syntactically,
/// turned into <see cref="NetworkProtocolTypeInfo"/> values, collected, paired with the
/// compilation and finally handed to <see cref="GenerateCode"/>.
/// </summary>
/// <param name="context">Initialization context supplied by the compiler.</param>
public void Initialize(IncrementalGeneratorInitializationContext context)
{
    // Syntactic pre-filter followed by semantic extraction of the protocol metadata.
    var protocolInfos = context.SyntaxProvider.CreateSyntaxProvider(
        predicate: static (syntaxNode, _) => IsNetworkProtocolClass(syntaxNode),
        transform: static (syntaxContext, _) => GetNetworkProtocolInfo(syntaxContext));

    // Candidates for which no info could be extracted are dropped before collecting.
    var collectedInfos = protocolInfos
        .Where(static candidate => candidate != null)
        .Collect();

    var pipeline = context.CompilationProvider.Combine(collectedInfos);

    context.RegisterSourceOutput(pipeline, static (productionContext, pair) =>
    {
        var compilation = pair.Left;

        // Nothing is generated unless CompilationHelper.HasFantasyDefine accepts the compilation
        // and the registrar interface the generated classes implement can be resolved.
        var canGenerate = CompilationHelper.HasFantasyDefine(compilation)
            && compilation.GetTypeByMetadataName("Fantasy.Assembly.INetworkProtocolRegistrar") != null;

        if (canGenerate)
        {
            GenerateCode(productionContext, compilation, pair.Right!);
        }
    });
}
#region GenerateCode
/// <summary>
/// Produces the three generated source files for the collected protocol types:
/// the type registrar, the OpCode resolver and the response-type resolver.
/// </summary>
/// <param name="context">Context used to add the generated sources.</param>
/// <param name="compilation">Current compilation; only its assembly name is read here.</param>
/// <param name="networkProtocolTypeInfos">Protocol descriptions gathered by the pipeline.</param>
private static void GenerateCode(
    SourceProductionContext context,
    Compilation compilation,
    IEnumerable<NetworkProtocolTypeInfo> networkProtocolTypeInfos)
{
    // Name of the current assembly (only used in comments of the generated code).
    var owningAssembly = compilation.AssemblyName ?? "Unknown";

    // Materialize once; the same list is handed to every emitter below.
    var protocols = networkProtocolTypeInfos.ToList();

    // Class that registers the network protocol types.
    GenerateNetworkProtocolTypesCode(context, owningAssembly, protocols);

    // OpCode helper methods.
    GenerateNetworkProtocolOpCodeResolverCode(context, owningAssembly, protocols);

    // ResponseType helper methods for request messages.
    GenerateNetworkProtocolResponseTypesResolverCode(context, owningAssembly, protocols);
}
/// <summary>
/// Emits <c>NetworkProtocolRegistrar.g.cs</c>, a class implementing
/// <c>INetworkProtocolRegistrar</c> whose <c>GetNetworkProtocolTypes()</c> returns a list
/// with one <c>typeof(...)</c> entry per collected protocol type.
/// </summary>
/// <param name="context">Context used to add the generated source.</param>
/// <param name="assemblyName">Assembly name, written into the generated XML comment only.</param>
/// <param name="networkProtocolTypeInfoList">Protocol types to register.</param>
private static void GenerateNetworkProtocolTypesCode(
    SourceProductionContext context,
    string assemblyName,
    List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
{
    var code = new SourceCodeBuilder();

    // File header.
    code.AppendLine(GeneratorConstants.AutoGeneratedHeader);

    // Using directives of the generated file.
    code.AddUsings(
        "System",
        "System.Collections.Generic",
        "Fantasy.Assembly",
        "Fantasy.DataStructure.Collection");

    // Open the namespace (always Fantasy.Generated).
    code.BeginNamespace("Fantasy.Generated");

    // Open the class (implements INetworkProtocolRegistrar).
    code.AddXmlComment($"Auto-generated NetworkProtocol registration class for {assemblyName}");
    code.BeginClass("NetworkProtocolRegistrar", "internal sealed", "INetworkProtocolRegistrar");
    code.BeginMethod("public List<Type> GetNetworkProtocolTypes()");

    if (networkProtocolTypeInfoList.Count == 0)
    {
        // No protocol types: return an empty list.
        code.AppendLine("return new List<Type>();");
    }
    else
    {
        // Collection initializer pre-sized to the number of protocol types.
        code.AppendLine($"return new List<Type>({networkProtocolTypeInfoList.Count})");
        code.AppendLine("{");
        code.Indent();

        for (var index = 0; index < networkProtocolTypeInfoList.Count; index++)
        {
            var protocol = networkProtocolTypeInfoList[index];
            code.AppendLine($"typeof({protocol.FullName}),");
        }

        code.Unindent();
        code.AppendLine("};");
        code.AppendLine();
    }

    code.EndMethod();
    code.AppendLine();

    // Close the class and the namespace.
    code.EndClass();
    code.EndNamespace();

    // Hand the generated source to the compiler.
    context.AddSource("NetworkProtocolRegistrar.g.cs", code.ToString());
}
/// <summary>
/// Emits <c>NetworkProtocolOpCodeResolverRegistrar.g.cs</c>, a class implementing
/// <c>INetworkProtocolOpCodeResolver</c> with four members: <c>GetOpCodeCount()</c>,
/// <c>GetCustomRouteTypeCount()</c>, <c>GetOpCodeType(uint)</c> (opcode to protocol type) and
/// <c>GetCustomRouteType(uint)</c> (opcode to route type, for protocols that have one).
/// </summary>
/// <param name="context">Context used to add the generated source.</param>
/// <param name="assemblyName">Assembly name, written into the generated XML comment only.</param>
/// <param name="networkProtocolTypeInfoList">All collected protocol types.</param>
private static void GenerateNetworkProtocolOpCodeResolverCode(
    SourceProductionContext context,
    string assemblyName,
    List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
{
    // Numeric literals written into generated C# are formatted with the invariant culture so
    // the output never depends on the culture of the machine running the compiler
    // (for example a culture-specific negative sign on a route type).
    var invariant = System.Globalization.CultureInfo.InvariantCulture;

    // Only protocols that carry a route type take part in the route-type lookup.
    var routeTypeInfos = networkProtocolTypeInfoList.Where(d => d.RouteType.HasValue).ToList();
    var builder = new SourceCodeBuilder();

    // File header.
    builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);

    // Using directives of the generated file.
    builder.AddUsings(
        "System",
        "System.Collections.Generic",
        "Fantasy.Assembly",
        "Fantasy.DataStructure.Collection",
        "System.Runtime.CompilerServices"
    );
    builder.AppendLine();

    // Open the namespace (always Fantasy.Generated).
    builder.BeginNamespace("Fantasy.Generated");

    // Open the class (implements INetworkProtocolOpCodeResolver).
    builder.AddXmlComment($"Auto-generated NetworkProtocolOpCodeResolverRegistrar class for {assemblyName}");
    builder.BeginClass("NetworkProtocolOpCodeResolverRegistrar", "internal sealed", "INetworkProtocolOpCodeResolver");

    // GetOpCodeCount: number of collected protocol types.
    builder.AddXmlComment("GetOpCodeCount");
    builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
    builder.BeginMethod("public int GetOpCodeCount()");
    builder.AppendLine($"return {networkProtocolTypeInfoList.Count.ToString(invariant)};");
    builder.EndMethod();

    // GetCustomRouteTypeCount: number of protocols that have a route type.
    builder.AddXmlComment("GetCustomRouteTypeCount");
    builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
    builder.BeginMethod("public int GetCustomRouteTypeCount()");
    builder.AppendLine($"return {routeTypeInfos.Count.ToString(invariant)};");
    builder.EndMethod();

    // GetOpCodeType: switch from opcode to the protocol type.
    builder.AddXmlComment("GetOpCodeType");
    builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
    builder.BeginMethod("public Type GetOpCodeType(uint opCode)");
    if (networkProtocolTypeInfoList.Any())
    {
        builder.AppendLine("switch (opCode)");
        builder.AppendLine("{");
        builder.Indent();
        foreach (var networkProtocolTypeInfo in networkProtocolTypeInfoList)
        {
            builder.AppendLine($"case {networkProtocolTypeInfo.OpCode.ToString(invariant)}:");
            builder.AppendLine("{");
            builder.Indent();
            builder.AppendLine($"return typeof({networkProtocolTypeInfo.FullName});");
            builder.Unindent();
            builder.AppendLine("}");
        }
        builder.AppendLine("default:");
        builder.AppendLine("{");
        builder.Indent();
        builder.AppendLine("return null!;");
        builder.Unindent();
        builder.AppendLine("}");
        builder.Unindent();
        builder.AppendLine("}");
        builder.AppendLine();
    }
    else
    {
        // No protocols at all: the generated method has nothing to look up.
        builder.AppendLine("return null!;");
    }
    builder.EndMethod();

    // GetCustomRouteType: switch from opcode to the route type.
    builder.AddXmlComment("CustomRouteType");
    builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
    builder.BeginMethod("public int? GetCustomRouteType(uint opCode)");
    if (routeTypeInfos.Any())
    {
        builder.AppendLine("switch (opCode)");
        builder.AppendLine("{");
        builder.Indent();
        foreach (var routeTypeInfo in routeTypeInfos)
        {
            builder.AppendLine($"case {routeTypeInfo.OpCode.ToString(invariant)}:");
            builder.AppendLine("{");
            builder.Indent();
            // routeTypeInfos only contains entries whose RouteType has a value (filtered above).
            builder.AppendLine($"return {routeTypeInfo.RouteType.GetValueOrDefault().ToString(invariant)};");
            builder.Unindent();
            builder.AppendLine("}");
        }
        builder.AppendLine("default:");
        builder.AppendLine("{");
        builder.Indent();
        builder.AppendLine("return null;");
        builder.Unindent();
        builder.AppendLine("}");
        builder.Unindent();
        builder.AppendLine("}");
        builder.AppendLine();
    }
    else
    {
        // No protocol has a route type.
        builder.AppendLine("return null;");
    }
    builder.EndMethod();
    builder.AppendLine();

    // Close the class and the namespace.
    builder.EndClass();
    builder.EndNamespace();

    // Hand the generated source to the compiler.
    context.AddSource("NetworkProtocolOpCodeResolverRegistrar.g.cs", builder.ToString());
}
/// <summary>
/// Emits <c>NetworkProtocolResponseTypeResolverRegistrar.g.cs</c>, a class implementing
/// <c>INetworkProtocolResponseTypeResolver</c> with <c>GetRequestCount()</c> and
/// <c>GetResponseType(uint)</c>, which maps the opcode of a request to its response type.
/// </summary>
/// <param name="context">Context used to add the generated source.</param>
/// <param name="assemblyName">Assembly name, written into the generated XML comment only.</param>
/// <param name="networkProtocolTypeInfoList">All collected protocol types.</param>
private static void GenerateNetworkProtocolResponseTypesResolverCode(
    SourceProductionContext context,
    string assemblyName,
    List<NetworkProtocolTypeInfo> networkProtocolTypeInfoList)
{
    // Opcode literals are formatted with the invariant culture so the generated C# does not
    // depend on the culture of the machine running the compiler.
    var invariant = System.Globalization.CultureInfo.InvariantCulture;

    // Only protocols that declare a response type are requests.
    var requestList = networkProtocolTypeInfoList.Where(d => d.ResponseType != null).ToList();
    var builder = new SourceCodeBuilder();

    // File header.
    builder.AppendLine(GeneratorConstants.AutoGeneratedHeader);

    // Using directives of the generated file.
    builder.AddUsings(
        "System",
        "System.Collections.Generic",
        "Fantasy.Assembly",
        "Fantasy.DataStructure.Collection",
        "System.Runtime.CompilerServices"
    );
    builder.AppendLine();

    // Open the namespace (always Fantasy.Generated).
    builder.BeginNamespace("Fantasy.Generated");

    // Open the class (implements INetworkProtocolResponseTypeResolver).
    builder.AddXmlComment($"Auto-generated NetworkProtocolResponseTypeResolverRegistrar class for {assemblyName}");
    builder.BeginClass("NetworkProtocolResponseTypeResolverRegistrar", "internal sealed", "INetworkProtocolResponseTypeResolver");

    // GetRequestCount: number of request protocols. The generated summary now names the
    // member it documents (it previously said "GetOpCodeCount").
    builder.AddXmlComment("GetRequestCount");
    builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
    builder.BeginMethod("public int GetRequestCount()");
    builder.AppendLine($"return {requestList.Count.ToString(invariant)};");
    builder.EndMethod();

    // GetResponseType: switch from request opcode to the response type. The generated
    // summary now names the member it documents (it previously said "GetOpCodeType").
    builder.AddXmlComment("GetResponseType");
    builder.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]");
    builder.BeginMethod("public Type GetResponseType(uint opCode)");
    if (requestList.Any())
    {
        builder.AppendLine("switch (opCode)");
        builder.AppendLine("{");
        builder.Indent();
        foreach (var request in requestList)
        {
            builder.AppendLine($"case {request.OpCode.ToString(invariant)}:");
            builder.AppendLine("{");
            builder.Indent();
            builder.AppendLine($"return typeof({request.ResponseType});");
            builder.Unindent();
            builder.AppendLine("}");
        }
        builder.AppendLine("default:");
        builder.AppendLine("{");
        builder.Indent();
        builder.AppendLine("return null!;");
        builder.Unindent();
        builder.AppendLine("}");
        builder.Unindent();
        builder.AppendLine("}");
        builder.AppendLine();
    }
    else
    {
        // No request protocols: the generated method has nothing to look up.
        builder.AppendLine("return null!;");
    }
    builder.EndMethod();
    builder.AppendLine();

    // Close the class and the namespace.
    builder.EndClass();
    builder.EndNamespace();

    // Hand the generated source to the compiler.
    context.AddSource("NetworkProtocolResponseTypeResolverRegistrar.g.cs", builder.ToString());
}
#endregion
/// <summary>
/// Extracts the protocol metadata from a candidate class declaration.
/// </summary>
/// <remarks>
/// A class is accepted when its symbol passes <c>IsInstantiable()</c>, its direct base type is
/// <c>Fantasy.Network.Interface.AMessage</c>, and its <c>OpCode()</c> method returns a
/// compile-time <c>uint</c> constant. <c>OpCode()</c> may have a block body (the first
/// <c>return</c> statement is inspected) or an expression body.
/// </remarks>
/// <param name="context">Syntax context whose node is a <see cref="ClassDeclarationSyntax"/>.</param>
/// <returns>The protocol description, or <c>null</c> when the class is not a usable protocol.</returns>
private static NetworkProtocolTypeInfo? GetNetworkProtocolInfo(GeneratorSyntaxContext context)
{
    var classDecl = (ClassDeclarationSyntax)context.Node;
    if (context.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol symbol || !symbol.IsInstantiable())
    {
        return null;
    }

    var baseType = symbol.BaseType;
    if (baseType is null || baseType.ToDisplayString() != "Fantasy.Network.Interface.AMessage")
    {
        return null;
    }

    // Evaluates an expression as a compile-time constant; null when there is no expression or
    // it is not constant. Members of a partial type can be declared in a different syntax tree
    // than the class declaration being visited, and a semantic model only accepts nodes of its
    // own tree, so the model matching the expression's tree is used.
    object? ConstantOf(ExpressionSyntax? expression)
    {
        if (expression is null)
        {
            return null;
        }

        var model = expression.SyntaxTree == context.SemanticModel.SyntaxTree
            ? context.SemanticModel
            : context.SemanticModel.Compilation.GetSemanticModel(expression.SyntaxTree);
        var constant = model.GetConstantValue(expression);
        return constant.HasValue ? constant.Value : null;
    }

    // Value returned by the OpCode() method.
    var opCodeMethod = symbol.GetMembers("OpCode").OfType<IMethodSymbol>().FirstOrDefault();
    var opCodeSyntax = opCodeMethod?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as MethodDeclarationSyntax;

    // Block body: the expression of the first return statement.
    // Expression body (`uint OpCode() => X;`): the arrow expression itself.
    ExpressionSyntax? opCodeExpression = opCodeSyntax?.Body != null
        ? opCodeSyntax.Body.DescendantNodes().OfType<ReturnStatementSyntax>().FirstOrDefault()?.Expression
        : opCodeSyntax?.ExpressionBody?.Expression;

    if (ConstantOf(opCodeExpression) is not uint opCodeValue)
    {
        return null;
    }

    // Type of the ResponseType property (for example G2C_TestResponse), when the class has one.
    string? responseTypeName = null;
    var responseTypeProperty = symbol.GetMembers("ResponseType").OfType<IPropertySymbol>().FirstOrDefault();
    if (responseTypeProperty != null)
    {
        responseTypeName = responseTypeProperty.Type.GetFullName();
    }

    // Value of the RouteType property; only an expression-bodied property is inspected.
    var routeTypeProperty = symbol.GetMembers("RouteType").OfType<IPropertySymbol>().FirstOrDefault();
    var routeTypeSyntax = routeTypeProperty?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as PropertyDeclarationSyntax;
    int? routeTypeValue = ConstantOf(routeTypeSyntax?.ExpressionBody?.Expression) is int routeType
        ? routeType
        : null;

    return new NetworkProtocolTypeInfo(
        symbol.GetFullName(),
        opCodeValue,
        responseTypeName,
        routeTypeValue
    );
}
/// <summary>
/// Syntactic pre-filter: accepts a class declaration that has a base list in which at least
/// one entry's text contains "AMessage". No semantic information is consulted here.
/// </summary>
/// <param name="node">Syntax node offered by the syntax provider.</param>
/// <returns><c>true</c> when the node is a candidate protocol class; otherwise <c>false</c>.</returns>
private static bool IsNetworkProtocolClass(SyntaxNode node)
{
    if (node is ClassDeclarationSyntax { BaseList: { } baseList })
    {
        var baseTypes = baseList.Types;
        for (var index = 0; index < baseTypes.Count; index++)
        {
            var baseTypeText = baseTypes[index].Type.ToString();
            if (baseTypeText.Contains("AMessage"))
            {
                return true;
            }
        }
    }

    // Not a class, no base list, or no base entry mentioning AMessage.
    return false;
}
/// <summary>
/// Description of one network protocol class, as consumed by the code emitters of this generator.
/// </summary>
/// <param name="FullName">Full name of the protocol type; written into <c>typeof(...)</c> in generated code.</param>
/// <param name="OpCode">Constant opcode of the protocol; used as a <c>case</c> label in generated switches.</param>
/// <param name="ResponseType">Full name of the type of the protocol's <c>ResponseType</c> property, or <c>null</c> when it has none.</param>
/// <param name="RouteType">Constant value of the protocol's <c>RouteType</c> property, or <c>null</c> when none was determined.</param>
private sealed record NetworkProtocolTypeInfo(
string FullName,
uint OpCode,
string? ResponseType,
int? RouteType);
}