LinqLinker.Generator/LinqLinkGenerator.cs
|
/*
    LinqLinker source generator for Enumerable static methods
    Copyright (C) 2026 Joshua 'Joan Metek(illot)' Kidder

    This module is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This module is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this module. If not, see <https://www.gnu.org/licenses/>.
*/

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;

namespace LinqLinker.Generators
{
    /// <summary>
    /// Incremental generator that mirrors every public static method of
    /// <c>System.Linq.Enumerable</c> onto the partial struct
    /// <c>System.Linq.LinqLinker&lt;TChainedLink&gt;</c>. For each method it emits an
    /// internal static forwarder named <c>__Name</c> and a public instance method
    /// named <c>Name</c> that calls the forwarder.
    /// </summary>
    [Generator]
    public class LinqLinkGenerator : IIncrementalGenerator
    {
        /// <summary>Display format used for every type name written into the generated source.</summary>
        static readonly SymbolDisplayFormat s_typeFmt = new SymbolDisplayFormat(
            globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted,
            typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes,
            genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters,
            miscellaneousOptions:
                SymbolDisplayMiscellaneousOptions.UseSpecialTypes |
                SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier |
                SymbolDisplayMiscellaneousOptions.EscapeKeywordIdentifiers
        );

        /// <summary>
        /// Registers the source output that produces <c>LinqLinks.g.cs</c> from the current compilation.
        /// </summary>
        /// <param name="context">Initialization context supplied by the compiler host.</param>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            context.RegisterSourceOutput(context.CompilationProvider, static (spc, compilation) =>
            {
                var enumerableType = compilation.GetTypeByMetadataName("System.Linq.Enumerable");
                if (enumerableType == null)
                    return; // Nothing to mirror when the type cannot be resolved.

                var methods = enumerableType.GetMembers()
                    .OfType<IMethodSymbol>()
                    .Where(m => m.IsStatic && m.DeclaredAccessibility == Accessibility.Public)
                    .ToList();

                var sb = new StringBuilder();
                sb.AppendLine("// Auto-generated code");
                sb.AppendLine("#nullable enable");
                sb.AppendLine("using System;");
                sb.AppendLine("using System.Collections;");
                sb.AppendLine("using System.Collections.Generic;");
                sb.AppendLine("using System.Linq;");
                sb.AppendLine("using System.Numerics;");
                sb.AppendLine();
                sb.AppendLine("namespace System.Linq{");
                sb.AppendLine("public readonly partial struct LinqLinker<TChainedLink> : IEnumerable<TChainedLink>");
                sb.AppendLine("{");

                foreach (var method in methods)
                {
                    EmitStaticMethod(sb, method);
                    EmitInstanceMethod(sb, method);
                    sb.AppendLine();
                }

                sb.AppendLine("}");
                sb.AppendLine("}");

                spc.AddSource("LinqLinks.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8));
            });
        }

        /// <summary>Renders <paramref name="type"/> using <see cref="s_typeFmt"/>.</summary>
        static string Fmt(ITypeSymbol type) => type.ToDisplayString(s_typeFmt);

        /// <summary>
        /// Returns the <c>T</c> of <c>IEnumerable&lt;T&gt;</c> when <paramref name="type"/> is that
        /// interface or a named type implementing it; otherwise <c>null</c>.
        /// </summary>
        static ITypeSymbol? GetEnumerableElementType(ITypeSymbol type)
        {
            if (type is INamedTypeSymbol named)
            {
                if (named.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T)
                    return named.TypeArguments[0];

                foreach (var iface in named.AllInterfaces)
                {
                    if (iface.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T)
                        return iface.TypeArguments[0];
                }
            }

            return null;
        }

        /// <summary>
        /// Decides whether a return type should be wrapped in <c>LinqLinker&lt;&gt;</c>: it must be an
        /// interface type that is, or implements, <c>IEnumerable&lt;T&gt;</c>.
        /// </summary>
        /// <param name="returnType">Return type of the mirrored method.</param>
        /// <param name="elementType">The sequence element type when the result is <c>true</c>.</param>
        static bool ShouldWrap(ITypeSymbol returnType, out ITypeSymbol? elementType)
        {
            elementType = null;
            if (returnType.TypeKind != TypeKind.Interface)
                return false;

            elementType = GetEnumerableElementType(returnType);
            return elementType != null;
        }

        /// <summary>
        /// Builds the <c>where</c> clauses for the given type parameters. The result carries a leading
        /// space when non-empty and is the empty string when no parameter has constraints.
        /// </summary>
        static string FormatConstraints(IEnumerable<ITypeParameterSymbol> typeParams)
        {
            var parts = new List<string>();
            foreach (var tp in typeParams)
            {
                var cs = new List<string>();
                if (tp.HasReferenceTypeConstraint) cs.Add("class");
                if (tp.HasValueTypeConstraint) cs.Add("struct");
                if (tp.HasNotNullConstraint) cs.Add("notnull");
                if (tp.HasUnmanagedTypeConstraint) cs.Add("unmanaged");
                foreach (var ct in tp.ConstraintTypes) cs.Add(Fmt(ct));
                if (tp.HasConstructorConstraint) cs.Add("new()"); // new() must come last in C#.

                if (cs.Count > 0)
                    parts.Add($"where {tp.Name} : {string.Join(", ", cs)}");
            }

            return parts.Count > 0 ? " " + string.Join(" ", parts) : "";
        }

        /// <summary>Formats a parameter declaration with its type rendered unchanged.</summary>
        static string FormatParam(IParameterSymbol p) => FormatParam(p, static typeText => typeText);

        /// <summary>
        /// Formats a parameter declaration: <c>params</c>/<c>ref</c>/<c>out</c>/<c>in</c> modifiers,
        /// type, name and, when the symbol declares one, its default value.
        /// </summary>
        /// <param name="p">Parameter to render.</param>
        /// <param name="rewriteType">Applied to every rendered type name (used to substitute the source type parameter).</param>
        static string FormatParam(IParameterSymbol p, Func<string, string> rewriteType)
        {
            var prefix = p.RefKind switch
            {
                RefKind.Ref => "ref ",
                RefKind.Out => "out ",
                RefKind.In => "in ",
                _ => ""
            };
            if (p.IsParams)
                prefix = "params " + prefix;

            return $"{prefix}{rewriteType(Fmt(p.Type))} {p.Name}{FormatDefaultValue(p, rewriteType)}";
        }

        /// <summary>
        /// Renders <c> = value</c> for an optional parameter, or the empty string when the parameter has
        /// no explicit default. Without this, wrappers would turn optional parameters into required ones.
        /// </summary>
        static string FormatDefaultValue(IParameterSymbol p, Func<string, string> rewriteType)
        {
            if (!p.HasExplicitDefaultValue)
                return "";

            var value = p.ExplicitDefaultValue;
            if (value is null)
                return " = default"; // Covers both `= null` and `= default`.
            if (value is bool flag)
                return flag ? " = true" : " = false";
            if (value is string text)
                return " = @\"" + text.Replace("\"", "\"\"") + "\""; // Verbatim literal: only quotes need escaping.

            // Remaining constants (numeric, char, enum) are written as a cast of an invariant-culture
            // number. A cast to a nullable type is not a constant expression, so unwrap Nullable<T>.
            // NOTE: non-finite float/double defaults (NaN, infinities) would not round-trip this way.
            var castType = p.Type;
            if (castType is INamedTypeSymbol named &&
                named.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
            {
                castType = named.TypeArguments[0];
            }

            var number = value is char ch
                ? ((int)ch).ToString(CultureInfo.InvariantCulture)
                : Convert.ToString(value, CultureInfo.InvariantCulture);

            // Parenthesize the operand so a negative value after a named type is not parsed as subtraction.
            return $" = ({rewriteType(Fmt(castType))})({number})";
        }

        /// <summary>Renders the argument expression that forwards <paramref name="p"/>, keeping its ref-kind.</summary>
        static string FormatArgument(IParameterSymbol p) => p.RefKind switch
        {
            RefKind.Ref => $"ref {p.Name}",
            RefKind.Out => $"out {p.Name}",
            RefKind.In => $"in {p.Name}",
            _ => p.Name
        };

        /// <summary>Renders <c>&lt;A, B&gt;</c> for the given names, or the empty string when there are none.</summary>
        static string FormatTypeArgumentList(IReadOnlyCollection<string> names)
            => names.Count > 0 ? "<" + string.Join(", ", names) + ">" : "";

        /// <summary>
        /// Replaces whole-identifier occurrences of <paramref name="identifier"/> in
        /// <paramref name="text"/>. A match is rejected when it is adjacent to another identifier
        /// character or preceded by <c>.</c>, so a type parameter named <c>T</c> does not rewrite the
        /// <c>T</c> inside <c>TResult</c> or a member of a qualified name.
        /// </summary>
        static string ReplaceIdentifier(string text, string identifier, string replacement)
        {
            if (identifier.Length == 0)
                return text;

            var result = new StringBuilder(text.Length);
            int pos = 0;
            while (pos < text.Length)
            {
                int hit = text.IndexOf(identifier, pos, StringComparison.Ordinal);
                if (hit < 0)
                    break;

                int end = hit + identifier.Length;
                bool cleanStart = hit == 0 || !(IsIdentifierChar(text[hit - 1]) || text[hit - 1] == '.');
                bool cleanEnd = end == text.Length || !IsIdentifierChar(text[end]);

                if (cleanStart && cleanEnd)
                {
                    result.Append(text, pos, hit - pos).Append(replacement);
                    pos = end;
                }
                else
                {
                    // Not a whole identifier: keep one character and resume searching after it.
                    result.Append(text, pos, hit - pos + 1);
                    pos = hit + 1;
                }
            }

            result.Append(text, pos, text.Length - pos);
            return result.ToString();
        }

        /// <summary>True for characters that can continue a C# identifier (letters, digits, underscore).</summary>
        static bool IsIdentifierChar(char c) => char.IsLetterOrDigit(c) || c == '_';

        /// <summary>
        /// Emits the internal static forwarder <c>__Name</c>, which has the same signature as
        /// <paramref name="method"/> and calls <c>System.Linq.Enumerable.Name</c>.
        /// </summary>
        static void EmitStaticMethod(StringBuilder sb, IMethodSymbol method)
        {
            var ret = Fmt(method.ReturnType);
            var name = method.Name;
            var tparams = FormatTypeArgumentList(method.TypeParameters.Select(tp => tp.Name).ToList());
            var parms = string.Join(", ", method.Parameters.Select(p => FormatParam(p)));
            var constraints = FormatConstraints(method.TypeParameters);
            var args = string.Join(", ", method.Parameters.Select(FormatArgument));

            sb.AppendLine($" internal static {ret} __{name}{tparams}({parms}){constraints}");
            sb.AppendLine($" => System.Linq.Enumerable.{name}{tparams}({args});");
        }

        /// <summary>
        /// Emits the public instance method for <paramref name="method"/>. When the first parameter is
        /// the non-generic <c>IEnumerable</c>, or an <c>IEnumerable&lt;T&gt;</c>-compatible type whose
        /// element type is a type parameter, that parameter is dropped and <c>_wrapped</c> is passed in
        /// its place; the element type parameter is replaced by <c>TChainedLink</c> throughout.
        /// Interface-typed sequence results are wrapped in a nullable <c>LinqLinker&lt;&gt;</c>.
        /// </summary>
        static void EmitInstanceMethod(StringBuilder sb, IMethodSymbol method)
        {
            bool hasSource = false;
            string? sourceTypeParamName = null;

            if (method.Parameters.Length > 0)
            {
                var firstType = method.Parameters[0].Type;
                if (GetEnumerableElementType(firstType) is ITypeParameterSymbol tps)
                {
                    hasSource = true;
                    sourceTypeParamName = tps.Name;
                }
                else if (firstType.SpecialType == SpecialType.System_Collections_IEnumerable)
                {
                    hasSource = true;
                }
            }

            // Substitutes the source type parameter with the struct's own type parameter.
            string Sub(string s) => sourceTypeParamName != null
                ? ReplaceIdentifier(s, sourceTypeParamName, "TChainedLink")
                : s;

            // Return type
            string returnTypeStr;
            bool wrap = ShouldWrap(method.ReturnType, out var wrapElemType);
            if (wrap)
            {
                var elemTypeStr = Sub(Fmt(wrapElemType!));
                returnTypeStr = $"LinqLinker<{elemTypeStr}>?";
            }
            else
            {
                returnTypeStr = Sub(Fmt(method.ReturnType));
            }

            var name = method.Name;

            // Type params: remove the source type param
            var instanceTypeParams = method.TypeParameters
                .Where(tp => tp.Name != sourceTypeParamName)
                .ToList();
            var tparams = FormatTypeArgumentList(instanceTypeParams.Select(tp => tp.Name).ToList());

            // Parameters: skip first if it's the source
            var instanceParams = hasSource
                ? method.Parameters.Skip(1).ToArray()
                : method.Parameters.ToArray();

            // Declare parameters with their modifiers and defaults so that out/ref results reach the
            // caller (e.g. an `out int count` parameter) and optional arguments stay optional.
            var parms = string.Join(", ", instanceParams.Select(p => FormatParam(p, Sub)));

            // Constraints for remaining type params
            var constraints = Sub(FormatConstraints(instanceTypeParams));

            // Build the call to the static method
            var staticTargs = FormatTypeArgumentList(
                method.TypeParameters
                    .Select(tp => tp.Name == sourceTypeParamName ? "TChainedLink" : tp.Name)
                    .ToList());

            var callArgParts = new List<string>();
            if (hasSource)
            {
                var firstParamTypeStr = Sub(Fmt(method.Parameters[0].Type));
                if (firstParamTypeStr == "IEnumerable<TChainedLink>" || firstParamTypeStr == "IEnumerable")
                    callArgParts.Add("_wrapped");
                else
                    callArgParts.Add($"({firstParamTypeStr})_wrapped"); // Narrower source type: cast the wrapped sequence.
            }
            callArgParts.AddRange(instanceParams.Select(FormatArgument));

            var callArgs = string.Join(", ", callArgParts);
            var callExpr = $"__{name}{staticTargs}({callArgs})";

            if (wrap)
            {
                // Return new LinqLinker wrapping the result, or null if result is null
                var nonNullReturnType = returnTypeStr.TrimEnd('?');
                sb.AppendLine($" public {returnTypeStr} {name}{tparams}({parms}){constraints}");
                sb.AppendLine($" {{");
                sb.AppendLine($" var __result = {callExpr};");
                sb.AppendLine($" return __result is null ? null : new {nonNullReturnType}(__result);");
                sb.AppendLine($" }}");
            }
            else
            {
                sb.AppendLine($" public {returnTypeStr} {name}{tparams}({parms}){constraints}");
                sb.AppendLine($" => {callExpr};");
            }
        }
    }
}