src/Microsoft.ML.DotNet.Interactive/DecisionTreeDataExtensions.cs

// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
 
using System;
using System.Linq;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
 
namespace Microsoft.ML.DotNet.Interactive
{
    public static class DecisionTreeDataExtensions
    {
        public static DecisionTreeData ToDecisionTreeData(this RegressionTreeEnsemble ensemble,in VBuffer<ReadOnlyMemory<char>> featureNames = default)
        {
            // just get the first tree, for now
            return ensemble.Trees.FirstOrDefault().ToDecisionTreeData(featureNames);
        }
 
        public static DecisionTreeData ToDecisionTreeData(this RegressionTree tree, in VBuffer<ReadOnlyMemory<char>> featureNames = default)
        {
            DecisionTreeData treeData = new DecisionTreeData();
 
            if (tree == null)
            {
                return treeData;
            }
 
            var nodes = new NodeData[tree.NumberOfNodes];
            var labelBuilder = new StringBuilder();
            for (int node = 0; node < tree.NumberOfNodes; node++)
            {
                nodes[node] = new NodeData();
                int featureIndex = tree.NumericalSplitFeatureIndexes[node];
                float splitThreshold = tree.NumericalSplitThresholds[node];
 
                ReadOnlyMemory<char> featureName = featureNames.GetItemOrDefault(featureIndex);
                if (!featureName.IsEmpty)
                {
                    labelBuilder.Append(featureName);
                }
                else
                {
                    labelBuilder.Append('f');
                    labelBuilder.Append(featureIndex);
                }
                labelBuilder.Append($" > ");
                labelBuilder.Append(splitThreshold.ToString("n2"));
 
                nodes[node].Label = labelBuilder.ToString();
                labelBuilder.Clear();
            }
 
            NodeData[] leaves = new NodeData[tree.NumberOfLeaves];
            for (int leaf = 0; leaf < tree.NumberOfLeaves; leaf++)
            {
                leaves[leaf] = new NodeData {Label = tree.LeafValues[leaf].ToString("n2")};
            }
 
            NodeData GetNodeData(int child)
            {
                return child >= 0
                    ? nodes[child]
                    : leaves[~child];
            }
 
            // hook the nodes up
            for (int node = 0; node < tree.NumberOfNodes; node++)
            {
                // the RightChild is the 'greater than' path, so put that first
                nodes[node].Children.Add(GetNodeData(tree.RightChild[node]));
                nodes[node].Children.Add(GetNodeData(tree.LeftChild[node]));
            }
 
            if (nodes.Length > 0)
            {
                treeData.Root = nodes[0];
            }
 
            return treeData;
        }
    }
}