// LinqLinker.Generator/LinqLinkGenerator.cs

/*
    LinqLinker source generator for Enumerable static methods
    Copyright (C) 2026 Joshua 'Joan Metek(illot)' Kidder

    This module is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This module is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this module. If not, see <https://www.gnu.org/licenses/>.
*/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;

namespace LinqLinker.Generators
{
    /// <summary>
    /// Incremental source generator that mirrors the public static methods of
    /// <c>System.Linq.Enumerable</c> onto the partial struct
    /// <c>System.Linq.LinqLinker&lt;TChainedLink&gt;</c>, emitted as <c>LinqLinks.g.cs</c>.
    /// </summary>
    /// <remarks>
    /// For every method two members are generated:
    /// <list type="bullet">
    /// <item><c>internal static __Name</c>, a one-to-one forwarder to <c>Enumerable.Name</c>;</item>
    /// <item>a public instance <c>Name</c> that passes the struct's <c>_wrapped</c> value as the
    /// source argument (when the method has one) and wraps sequence-interface results in a new
    /// <c>LinqLinker</c>.</item>
    /// </list>
    /// NOTE(review): <c>_wrapped</c> and the <c>LinqLinker</c> constructors used by the generated code
    /// are presumably declared in the hand-written half of the partial struct; they are not visible here.
    /// </remarks>
    [Generator]
    public class LinqLinkGenerator : IIncrementalGenerator
    {
        /// <summary>The struct's type parameter; it replaces the source element type parameter in instance members.</summary>
        const string ChainedTypeParam = "TChainedLink";

        /// <summary>
        /// Display format for emitted type names: namespaces omitted (the generated file relies on its
        /// own <c>using</c> directives), containing types and type parameters kept, C# keywords for
        /// special types, nullable-reference <c>?</c> kept, keyword identifiers escaped.
        /// </summary>
        static readonly SymbolDisplayFormat s_typeFmt = new SymbolDisplayFormat(
            globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted,
            typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes,
            genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters,
            miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes
                | SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier
                | SymbolDisplayMiscellaneousOptions.EscapeKeywordIdentifiers
        );

        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Render inside a Select step: the resulting string is equatable, so the output step
            // (and AddSource) is skipped whenever an edit leaves the generated text unchanged.
            var generated = context.CompilationProvider
                .Select(static (compilation, _) => RenderSource(compilation));

            context.RegisterSourceOutput(generated, static (spc, text) =>
            {
                if (text is null) return;
                spc.AddSource("LinqLinks.g.cs", SourceText.From(text, Encoding.UTF8));
            });
        }

        /// <summary>
        /// Builds the complete text of <c>LinqLinks.g.cs</c>.
        /// </summary>
        /// <returns>
        /// The generated source, or <c>null</c> when the compilation cannot resolve
        /// <c>System.Linq.Enumerable</c>.
        /// </returns>
        static string? RenderSource(Compilation compilation)
        {
            var enumerableType = compilation.GetTypeByMetadataName("System.Linq.Enumerable");
            if (enumerableType is null) return null;

            var sb = new StringBuilder();
            sb.AppendLine("// Auto-generated code");
            sb.AppendLine("#nullable enable");
            sb.AppendLine("using System;");
            sb.AppendLine("using System.Collections;");
            sb.AppendLine("using System.Collections.Generic;");
            sb.AppendLine("using System.Linq;");
            sb.AppendLine("using System.Numerics;");
            sb.AppendLine();
            sb.AppendLine("namespace System.Linq{");
            sb.AppendLine("public readonly partial struct LinqLinker<TChainedLink> : IEnumerable<TChainedLink>");
            sb.AppendLine("{");

            var methods = enumerableType.GetMembers()
                .OfType<IMethodSymbol>()
                .Where(m => m.IsStatic && m.DeclaredAccessibility == Accessibility.Public);

            foreach (var method in methods)
            {
                EmitStaticMethod(sb, method);
                EmitInstanceMethod(sb, method);
                sb.AppendLine();
            }

            sb.AppendLine("}");
            sb.AppendLine("}");
            return sb.ToString();
        }

        static string Fmt(ITypeSymbol type) => type.ToDisplayString(s_typeFmt);

        /// <summary>
        /// Returns <c>T</c> when <paramref name="type"/> is <c>IEnumerable&lt;T&gt;</c> or a named type
        /// implementing it (the first such interface in <c>AllInterfaces</c>); otherwise <c>null</c>.
        /// Arrays and type parameters are not named types and yield <c>null</c>.
        /// </summary>
        static ITypeSymbol? GetEnumerableElementType(ITypeSymbol type)
        {
            if (type is INamedTypeSymbol named)
            {
                if (named.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T)
                    return named.TypeArguments[0];

                foreach (var iface in named.AllInterfaces)
                {
                    if (iface.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T)
                        return iface.TypeArguments[0];
                }
            }
            return null;
        }

        /// <summary>
        /// True when <paramref name="returnType"/> is an interface that is or implements
        /// <c>IEnumerable&lt;T&gt;</c>, in which case the instance member returns
        /// <c>LinqLinker&lt;T&gt;</c> instead.
        /// NOTE: this also matches interfaces such as <c>ILookup&lt;K, V&gt;</c>, which is then
        /// exposed as a sequence of its groupings.
        /// </summary>
        static bool ShouldWrap(ITypeSymbol returnType, out ITypeSymbol? elementType)
        {
            elementType = null;
            if (returnType.TypeKind != TypeKind.Interface)
                return false;
            elementType = GetEnumerableElementType(returnType);
            return elementType != null;
        }

        /// <summary>
        /// Renders the <c>where</c> clauses for <paramref name="typeParams"/>, prefixed with a space,
        /// or an empty string if none are constrained.
        /// </summary>
        static string FormatConstraints(IEnumerable<ITypeParameterSymbol> typeParams)
        {
            var parts = new List<string>();
            foreach (var tp in typeParams)
            {
                var cs = new List<string>();

                // C# allows at most one primary constraint, and it must come first. Roslyn also
                // reports HasValueTypeConstraint for `unmanaged`, so test the most specific kind first.
                if (tp.HasUnmanagedTypeConstraint) cs.Add("unmanaged");
                else if (tp.HasValueTypeConstraint) cs.Add("struct");
                else if (tp.HasReferenceTypeConstraint)
                    cs.Add(tp.ReferenceTypeConstraintNullableAnnotation == NullableAnnotation.Annotated ? "class?" : "class");
                else if (tp.HasNotNullConstraint) cs.Add("notnull");

                foreach (var ct in tp.ConstraintTypes) cs.Add(Fmt(ct));

                // `new()` is implied by, and illegal alongside, `struct`/`unmanaged`. Metadata encodes
                // `struct` with the default-constructor flag, so it can surface here.
                if (tp.HasConstructorConstraint && !tp.HasValueTypeConstraint && !tp.HasUnmanagedTypeConstraint)
                    cs.Add("new()");

                if (cs.Count > 0)
                    parts.Add($"where {tp.Name} : {string.Join(", ", cs)}");
            }
            return parts.Count > 0 ? " " + string.Join(" ", parts) : "";
        }

        /// <summary>
        /// True when <paramref name="tp"/> carries any constraint that C# treats as an error to
        /// violate, i.e. anything other than <c>notnull</c>, which only produces a nullable warning.
        /// </summary>
        static bool HasBindingConstraint(ITypeParameterSymbol tp) =>
            tp.HasReferenceTypeConstraint
            || tp.HasValueTypeConstraint
            || tp.HasUnmanagedTypeConstraint
            || tp.HasConstructorConstraint
            || tp.ConstraintTypes.Length > 0;

        /// <summary>Renders <c>&lt;A, B&gt;</c> from the given names, or an empty string if there are none.</summary>
        static string FormatTypeArgs(IEnumerable<string> names)
        {
            var joined = string.Join(", ", names);
            return joined.Length == 0 ? "" : "<" + joined + ">";
        }

        /// <summary>
        /// Renders one parameter declaration, with its <c>params</c>/<c>ref</c>/<c>out</c>/<c>in</c> modifier.
        /// <paramref name="sub"/> is applied to the formatted parameter type.
        /// </summary>
        static string FormatParam(IParameterSymbol p, Func<string, string> sub)
        {
            var prefix = p.RefKind switch
            {
                RefKind.Ref => "ref ",
                RefKind.Out => "out ",
                RefKind.In => "in ",
                _ => ""
            };
            if (p.IsParams) prefix = "params " + prefix;
            return $"{prefix}{sub(Fmt(p.Type))} {p.Name}";
        }

        /// <summary>
        /// Renders a comma-separated parameter list. Default values are re-emitted when
        /// <see cref="TryFormatDefault"/> can render them, but only on a trailing run of parameters.
        /// This way the output never puts a required parameter after an optional one.
        /// </summary>
        static string FormatParams(IReadOnlyList<IParameterSymbol> parameters, Func<string, string> sub)
        {
            var rendered = new string[parameters.Count];
            bool trailingOptional = true;
            for (int i = parameters.Count - 1; i >= 0; i--)
            {
                var p = parameters[i];
                var text = FormatParam(p, sub);
                if (p.IsParams)
                {
                    // A params array may legally follow optional parameters; it never has a default itself.
                }
                else if (trailingOptional && TryFormatDefault(p, out var defaultText))
                {
                    text += " = " + defaultText;
                }
                else
                {
                    trailingOptional = false;
                }
                rendered[i] = text;
            }
            return string.Join(", ", rendered);
        }

        /// <summary>
        /// Renders <paramref name="p"/>'s explicit default value as a C# constant expression.
        /// Only <c>null</c>, <c>bool</c> and integral (including enum) values are handled; other
        /// defaults return <c>false</c> and are omitted.
        /// </summary>
        static bool TryFormatDefault(IParameterSymbol p, out string text)
        {
            text = "";
            if (!p.HasExplicitDefaultValue) return false;

            switch (p.ExplicitDefaultValue)
            {
                case null:
                    // `default` is valid for reference, nullable and value types alike.
                    text = "default";
                    return true;
                case bool b:
                    text = b ? "true" : "false";
                    return true;
                case sbyte or byte or short or ushort or int or uint or long or ulong:
                    // Enum-typed defaults are reported as their underlying integer, so cast through the
                    // declared type. The parenthesised literal keeps negative values parsing as a cast.
                    var literal = Convert.ToString(p.ExplicitDefaultValue, System.Globalization.CultureInfo.InvariantCulture);
                    text = $"({Fmt(p.Type)})({literal})";
                    return true;
                default:
                    return false;
            }
        }

        /// <summary>Renders the argument expression that forwards <paramref name="p"/> with its ref-kind.</summary>
        static string FormatArgument(IParameterSymbol p) => p.RefKind switch
        {
            RefKind.Ref => $"ref {p.Name}",
            RefKind.Out => $"out {p.Name}",
            RefKind.In => $"in {p.Name}",
            _ => p.Name
        };

        /// <summary>
        /// Replaces whole-identifier occurrences of <paramref name="identifier"/> in
        /// <paramref name="text"/>. Occurrences embedded in a longer identifier are left untouched.
        /// </summary>
        static string ReplaceIdentifier(string text, string identifier, string replacement)
        {
            if (identifier.Length == 0) return text;

            static bool IsIdentChar(char c) => char.IsLetterOrDigit(c) || c == '_';

            var result = new StringBuilder(text.Length);
            int pos = 0;
            while (pos < text.Length)
            {
                int hit = text.IndexOf(identifier, pos, StringComparison.Ordinal);
                if (hit < 0)
                {
                    result.Append(text, pos, text.Length - pos);
                    break;
                }

                int end = hit + identifier.Length;
                bool wholeWord = (hit == 0 || !IsIdentChar(text[hit - 1]))
                    && (end == text.Length || !IsIdentChar(text[end]));

                result.Append(text, pos, hit - pos);
                result.Append(wholeWord ? replacement : identifier);
                pos = end;
            }
            return result.ToString();
        }

        /// <summary>
        /// Decides whether <paramref name="method"/>'s first parameter is the sequence that the instance
        /// member supplies from <c>_wrapped</c>. That is the case when it is a generic sequence whose
        /// element type is a type parameter, or the non-generic <c>IEnumerable</c>.
        /// </summary>
        /// <param name="sourceTypeParam">The element type parameter in the generic case; otherwise <c>null</c>.</param>
        static bool TryGetSource(IMethodSymbol method, out ITypeParameterSymbol? sourceTypeParam)
        {
            sourceTypeParam = null;
            if (method.Parameters.Length == 0) return false;

            var firstType = method.Parameters[0].Type;
            if (GetEnumerableElementType(firstType) is ITypeParameterSymbol tps)
            {
                sourceTypeParam = tps;
                return true;
            }
            return firstType.SpecialType == SpecialType.System_Collections_IEnumerable;
        }

        /// <summary>
        /// Emits <c>internal static __Name</c>. It has the same type parameters, constraints and
        /// parameters (with modifiers and renderable defaults) as the Enumerable method, and its body
        /// forwards straight to <c>System.Linq.Enumerable.Name</c>.
        /// </summary>
        static void EmitStaticMethod(StringBuilder sb, IMethodSymbol method)
        {
            var ret = Fmt(method.ReturnType);
            var name = method.Name;
            var tparams = FormatTypeArgs(method.TypeParameters.Select(tp => tp.Name));
            var parms = FormatParams(method.Parameters, s => s);
            var constraints = FormatConstraints(method.TypeParameters);
            var args = string.Join(", ", method.Parameters.Select(FormatArgument));

            sb.AppendLine($" internal static {ret} __{name}{tparams}({parms}){constraints}");
            sb.AppendLine($" => System.Linq.Enumerable.{name}{tparams}({args});");
        }

        /// <summary>
        /// Emits the public instance counterpart of an Enumerable method.
        /// When the method has a source sequence, the first parameter is dropped and <c>_wrapped</c>
        /// is passed instead, with the source element type parameter bound to <c>TChainedLink</c>.
        /// Sequence-interface results are wrapped in <c>LinqLinker&lt;T&gt;</c>.
        /// Methods whose source type parameter is constrained are skipped. An instance member cannot
        /// constrain the struct's <c>TChainedLink</c>, so forwarding them would not compile.
        /// </summary>
        static void EmitInstanceMethod(StringBuilder sb, IMethodSymbol method)
        {
            bool hasSource = TryGetSource(method, out var sourceTypeParam);
            if (sourceTypeParam is not null && HasBindingConstraint(sourceTypeParam))
                return;

            var sourceName = sourceTypeParam?.Name;

            // Binds the source element type parameter to the struct's type parameter in a formatted string.
            string Sub(string s) =>
                sourceName is null ? s : ReplaceIdentifier(s, sourceName, ChainedTypeParam);

            bool wrap = ShouldWrap(method.ReturnType, out var wrapElemType);
            var returnTypeStr = wrap
                ? $"LinqLinker<{Sub(Fmt(wrapElemType!))}>"
                : Sub(Fmt(method.ReturnType));

            var name = method.Name;

            // The source type parameter is supplied by the struct, so it is not redeclared here.
            var instanceTypeParams = method.TypeParameters
                .Where(tp => tp.Name != sourceName)
                .ToList();
            var tparams = FormatTypeArgs(instanceTypeParams.Select(tp => tp.Name));

            // The source parameter (if any) is replaced by _wrapped.
            var instanceParams = hasSource
                ? method.Parameters.Skip(1).ToArray()
                : method.Parameters.ToArray();
            var parms = FormatParams(instanceParams, Sub);
            var constraints = Sub(FormatConstraints(instanceTypeParams));

            var staticTargs = FormatTypeArgs(method.TypeParameters.Select(tp =>
                tp.Name == sourceName ? ChainedTypeParam : tp.Name));

            var callArgParts = new List<string>();
            if (hasSource)
            {
                // Plain IEnumerable<T> / IEnumerable parameters accept _wrapped directly; more derived
                // source types (e.g. an ordered sequence) need an explicit runtime cast.
                var firstType = method.Parameters[0].Type;
                bool passDirectly =
                    firstType.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T
                    || firstType.SpecialType == SpecialType.System_Collections_IEnumerable;
                callArgParts.Add(passDirectly ? "_wrapped" : $"({Sub(Fmt(firstType))})_wrapped");
            }
            callArgParts.AddRange(instanceParams.Select(FormatArgument));

            var callExpr = $"__{name}{staticTargs}({string.Join(", ", callArgParts)})";

            sb.AppendLine($" public {returnTypeStr} {name}{tparams}({parms}){constraints}");
            if (wrap)
            {
                // Wrap the result in a new LinqLinker; a null result becomes an empty one.
                sb.AppendLine($" {{");
                sb.AppendLine($" var __result = {callExpr};");
                sb.AppendLine($" return __result is null ? new {returnTypeStr}([]) : new {returnTypeStr}(__result);");
                sb.AppendLine($" }}");
            }
            else
            {
                sb.AppendLine($" => {callExpr};");
            }
        }
    }
}