cs/TableFunctionWrapper.cs

#pragma warning disable DuckDBNET001
#pragma warning disable CS1701
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Management.Automation;
using System.Reflection;
using DuckDB.NET.Data;
using DuckDB.NET.Data.DataChunk.Writer;
using DuckDB.NET.Native;
 
namespace PaperinikDB
{
    /// <summary>
    /// Wraps a PowerShell ScriptBlock to be used as a DuckDB table function.
    /// </summary>
    public class TableFunctionWrapper
    {
        private readonly ScriptBlock _bindScriptBlock;
        private readonly ColumnInfo[] _columns;
        private readonly Type[] _columnTypes;
        private readonly Type[] _parameterTypes;
 
        public TableFunctionWrapper(ScriptBlock bindScriptBlock, ColumnInfo[] columns, Type[] columnTypes)
            : this(bindScriptBlock, columns, columnTypes, Array.Empty<Type>())
        {
        }
 
        public TableFunctionWrapper(ScriptBlock bindScriptBlock, ColumnInfo[] columns, Type[] columnTypes, Type[] parameterTypes)
        {
            _bindScriptBlock = bindScriptBlock ?? throw new ArgumentNullException(nameof(bindScriptBlock));
            _columns = columns ?? throw new ArgumentNullException(nameof(columns));
            _columnTypes = columnTypes ?? throw new ArgumentNullException(nameof(columnTypes));
            _parameterTypes = parameterTypes ?? Array.Empty<Type>();
        }
 
        /// <summary>
        /// Called when the table function is bound (with no input parameters).
        /// Returns the column schema and the data to iterate over.
        /// </summary>
        public TableFunction Bind()
        {
            var results = _bindScriptBlock.Invoke();
            IEnumerable data = ExtractDataFromResults(results);
            return new TableFunction(_columns, data ?? Array.Empty<object>());
        }
         
        /// <summary>
        /// Extracts the data collection from ScriptBlock.Invoke() results.
        /// Handles various return patterns from PowerShell.
        /// </summary>
        private IEnumerable ExtractDataFromResults(System.Collections.ObjectModel.Collection<PSObject> results)
        {
            if (results == null || results.Count == 0)
                return Array.Empty<object>();
             
            // Case 1: ScriptBlock returned multiple items directly (e.g., array of PSCustomObjects)
            // Each array element becomes a separate PSObject in results
            if (results.Count > 1)
            {
                // The results collection itself contains each row
                return results;
            }
             
            // Case 2: Single result - check what it is
            var singleResult = results[0];
            if (singleResult == null)
                return Array.Empty<object>();
             
            var baseObject = singleResult.BaseObject;
             
            // Case 2a: BaseObject is null or is PSObject itself (PSCustomObject case)
            if (baseObject == null || baseObject == singleResult ||
                baseObject.GetType().FullName == "System.Management.Automation.PSCustomObject")
            {
                // Single PSCustomObject - wrap in array
                return new object[] { singleResult };
            }
             
            // Case 2b: BaseObject is an enumerable (script returned @(...) that got unwrapped)
            if (baseObject is IEnumerable enumerable && !(baseObject is string))
            {
                return enumerable;
            }
             
            // Case 2c: Single non-enumerable value - wrap in array
            return new object[] { singleResult };
        }
 
        /// <summary>
        /// Called when the table function is bound (with input parameters).
        /// Returns the column schema and the data to iterate over.
        /// </summary>
        public TableFunction BindWithParams(IReadOnlyList<IDuckDBValueReader> readers)
        {
            // Collect input parameter values using the known parameter types
            var inputValues = new object[readers.Count];
            for (int i = 0; i < readers.Count; i++)
            {
                if (!readers[i].IsNull())
                {
                    // Determine the type to use for GetValue<T>
                    Type paramType = (i < _parameterTypes.Length) ? _parameterTypes[i] : typeof(object);
                     
                    // Use reflection to call the generic GetValue<T> method with the correct type
                    var getValueMethod = readers[i].GetType().GetMethod("GetValue");
                    if (getValueMethod != null && getValueMethod.IsGenericMethodDefinition)
                    {
                        var closedMethod = getValueMethod.MakeGenericMethod(paramType);
                        inputValues[i] = closedMethod.Invoke(readers[i], null);
                    }
                }
            }
 
            var results = _bindScriptBlock.Invoke(inputValues);
            IEnumerable data = ExtractDataFromResults(results);
             
            return new TableFunction(_columns, data ?? Array.Empty<object>());
        }
 
        /// <summary>
        /// Maps each row from the data enumerator to the output columns.
        /// </summary>
        public void MapRow(object row, IDuckDBDataWriter[] writers, ulong rowIndex)
        {
            if (row == null) return;
 
            try
            {
                // Handle PSObject wrapper
                if (row is PSObject psObj)
                {
                    var baseObject = psObj.BaseObject;
                     
                    // Check if this is a PSCustomObject (BaseObject is null, same as psObj, or is PSCustomObject type)
                    bool isPSCustomObject = baseObject == null ||
                                            object.ReferenceEquals(baseObject, psObj) ||
                                            baseObject.GetType().FullName == "System.Management.Automation.PSCustomObject" ||
                                            baseObject is PSObject;
                     
                    if (isPSCustomObject)
                    {
                        // Access properties by column name from the PSObject
                        for (int col = 0; col < _columns.Length && col < writers.Length; col++)
                        {
                            var propValue = psObj.Properties[_columns[col].Name]?.Value;
                            WriteValueToWriter(writers[col], propValue, _columnTypes[col], rowIndex);
                        }
                        return;
                    }
                     
                    // Not a PSCustomObject - continue with BaseObject
                    row = baseObject;
                }
 
                // Handle dictionary/hashtable
                if (row is IDictionary dict)
                {
                    for (int col = 0; col < _columns.Length && col < writers.Length; col++)
                    {
                        var colName = _columns[col].Name;
                        var value = dict.Contains(colName) ? dict[colName] : null;
                        WriteValueToWriter(writers[col], value, _columnTypes[col], rowIndex);
                    }
                    return;
                }
 
                // Handle array/list (positional)
                if (row is IList list)
                {
                    for (int col = 0; col < _columns.Length && col < writers.Length && col < list.Count; col++)
                    {
                        WriteValueToWriter(writers[col], list[col], _columnTypes[col], rowIndex);
                    }
                    return;
                }
 
                // Handle object with properties (reflection)
                var rowType = row.GetType();
                for (int col = 0; col < _columns.Length && col < writers.Length; col++)
                {
                    var prop = rowType.GetProperty(_columns[col].Name);
                    if (prop != null)
                    {
                        var value = prop.GetValue(row);
                        WriteValueToWriter(writers[col], value, _columnTypes[col], rowIndex);
                    }
                    else
                    {
                        writers[col].WriteNull(rowIndex);
                    }
                }
            }
            catch
            {
                // On error, write nulls for all columns
                for (int col = 0; col < writers.Length; col++)
                {
                    writers[col].WriteNull(rowIndex);
                }
            }
        }
 
        private void WriteValueToWriter(IDuckDBDataWriter writer, object value, Type targetType, ulong rowIndex)
        {
            if (value == null)
            {
                writer.WriteNull(rowIndex);
                return;
            }
 
            try
            {
                // Unwrap PSObject if needed
                if (value is PSObject psValue)
                {
                    value = psValue.BaseObject;
                }
 
                var convertedValue = Convert.ChangeType(value, targetType);
 
                // Use reflection to call WriteValue<T>
                var writeValueMethod = writer.GetType().GetMethod("WriteValue");
                if (writeValueMethod != null && writeValueMethod.IsGenericMethodDefinition)
                {
                    var genericMethod = writeValueMethod.MakeGenericMethod(targetType);
                    genericMethod.Invoke(writer, new object[] { convertedValue, rowIndex });
                }
                else if (writeValueMethod != null)
                {
                    writeValueMethod.Invoke(writer, new object[] { convertedValue, rowIndex });
                }
                else
                {
                    writer.WriteNull(rowIndex);
                }
            }
            catch
            {
                writer.WriteNull(rowIndex);
            }
        }
    }
}