#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.PermutationEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.PluginInfrastructure;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;
using HEAL.Attic;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression algorithm that learns either a decision tree or (when <see cref="GenerateRules"/> is set)
  /// a regression rule set. The configurable parts are:
  /// <list type="bullet">
  /// <item>split function (<see cref="Splitter"/>)</item>
  /// <item>leaf model (<see cref="LeafModel"/>)</item>
  /// <item>pruning strategy (<see cref="Pruning"/>)</item>
  /// <item>optional holdout set for pruning</item>
  /// </list>
  /// The algorithm supports pausing; its working state lives in <c>stateScope</c>.
  /// </summary>
  [StorableType("FC8D8E5A-D16D-41BB-91CF-B2B35D17ADD7")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("Decision Tree Regression (DT)", "A regression tree / rule set learner")]
  public sealed class DecisionTreeRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    public override bool SupportsPause {
      get { return true; }
    }

    // Names of the variables stored in the state scope (see InitializeScope / Build).
    public const string RegressionTreeParameterVariableName = "RegressionTreeParameters";
    public const string ModelVariableName = "Model";
    public const string PruningSetVariableName = "PruningSet";
    public const string TrainingSetVariableName = "TrainingSet";

    #region Parameter names
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string HoldoutSizeParameterName = "HoldoutSize";
    private const string SplitterParameterName = "Splitter";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string LeafModelParameterName = "LeafModel";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string UseHoldoutParameterName = "UseHoldout";
    #endregion

    #region Parameter properties
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
    }
    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
    }
    public IConstrainedValueParameter<ISplitter> SplitterParameter {
      get { return (IConstrainedValueParameter<ISplitter>)Parameters[SplitterParameterName]; }
    }
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
    }
    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
    }
    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    }
    #endregion

    #region Properties
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
      set { GenerateRulesParameter.Value.Value = value; }
    }
    public double HoldoutSize {
      get { return HoldoutSizeParameter.Value.Value; }
      set { HoldoutSizeParameter.Value.Value = value; }
    }
    public ISplitter Splitter {
      get { return SplitterParameter.Value; }
      // no setter because this is a constrained parameter
    }
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
      set { MinimalNodeSizeParameter.Value.Value = value; }
    }
    public ILeafModel LeafModel {
      get { return LeafModelParameter.Value; }
    }
    public IPruning Pruning {
      get { return PruningTypeParameter.Value; }
    }
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    public bool UseHoldout {
      get { return UseHoldoutParameter.Value.Value; }
      set { UseHoldoutParameter.Value.Value = value; }
    }
    #endregion

    #region State
    // Persisted working state (parameters, unbuilt/partially built model, row sets).
    // Stored so that a paused run can be resumed; the field name must stay stable for serialization.
    [Storable]
    private IScope stateScope;
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private DecisionTreeRegression(StorableConstructorFlag _) : base(_) { }

    private DecisionTreeRegression(DecisionTreeRegression original, Cloner cloner) : base(original, cloner) {
      stateScope = cloner.Clone(stateScope);
    }

    /// <summary>
    /// Creates the algorithm with its default parameters and an empty <see cref="RegressionProblem"/>.
    /// Valid splitters, leaf models and pruning types are all instances discovered by the plugin
    /// infrastructure. The defaults are <c>Splitter</c>, <c>LinearLeaf</c> and <c>ComplexityPruning</c>.
    /// </summary>
    public DecisionTreeRegression() {
      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
      var splitterSet = new ItemSet<ISplitter>(ApplicationManager.Manager.GetInstances<ISplitter>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created (default=false)", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning (default=20%).", new PercentValue(0.2)));
      Parameters.Add(new ConstrainedValueParameter<ISplitter>(SplitterParameterName, "The type of split function used to create node splits (default='Splitter').", splitterSet, splitterSet.OfType<Splitter>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node (default=1).", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes (default='LinearLeaf').", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used (default='ComplexityPruning').", pruningSet, pruningSet.OfType<ComplexityPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data (default=false).", new BoolValue(false)));
      Problem = new RegressionProblem();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new DecisionTreeRegression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Prepares a fresh run. If <see cref="SetSeedRandomly"/> is set, a new seed is drawn first.
    /// A <see cref="MersenneTwister"/> seeded with <see cref="Seed"/> is used to build the state scope.
    /// The state scope is then published as the "StateScope" result.
    /// </summary>
    protected override void Initialize(CancellationToken cancellationToken) {
      base.Initialize(cancellationToken);
      var random = new MersenneTwister();
      if (SetSeedRandomly) Seed = RandomSeedGenerator.GetSeed();
      random.Reset(Seed);
      stateScope = InitializeScope(random, Problem.ProblemData, Pruning, MinimalNodeSize, LeafModel, Splitter, GenerateRules, UseHoldout, HoldoutSize);
      stateScope.Variables.Add(new Variable("Algorithm", this));
      Results.AddOrUpdateResult("StateScope", stateScope);
    }

    /// <summary>
    /// Builds the model from the prepared state scope and adds solution/analysis results.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var model = Build(stateScope, Results, cancellationToken);
      AnalyzeSolution(model.CreateRegressionSolution(Problem.ProblemData), Results, Problem.ProblemData);
    }

    #region Static Interface
    /// <summary>
    /// Trains a decision tree (or rule set) on the training partition of <paramref name="problemData"/>
    /// without instantiating the algorithm item.
    /// </summary>
    /// <param name="problemData">Regression data; all allowed inputs and the target must be double-valued.</param>
    /// <param name="random">Random number generator used for operators and the holdout split.</param>
    /// <param name="leafModel">Leaf model; a new <c>LinearLeaf</c> when <c>null</c>.</param>
    /// <param name="splitter">Split function; a new <c>Splitter</c> when <c>null</c>.</param>
    /// <param name="pruning">Pruning strategy; a new <c>ComplexityPruning</c> when <c>null</c>.</param>
    /// <param name="useHoldout">Whether a disjoint random subset of the training rows is reserved for pruning.</param>
    /// <param name="holdoutSize">Fraction of the training rows reserved for pruning when <paramref name="useHoldout"/> is set.</param>
    /// <param name="minimumLeafSize">Minimal number of samples per leaf.</param>
    /// <param name="generateRules">Build a rule set instead of a tree.</param>
    /// <param name="results">Optional collection that receives intermediate results.</param>
    /// <param name="cancellationToken">Optional cancellation token; <see cref="CancellationToken.None"/> when omitted.</param>
    /// <returns>A regression solution for <paramref name="problemData"/>.</returns>
    /// <exception cref="NotSupportedException">Non-double variables, or NaN/infinite training values.</exception>
    public static IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, IRandom random, ILeafModel leafModel = null, ISplitter splitter = null, IPruning pruning = null,
      bool useHoldout = false, double holdoutSize = 0.2, int minimumLeafSize = 1, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
      leafModel = leafModel ?? new LinearLeaf();
      splitter = splitter ?? new Splitter();
      pruning = pruning ?? new ComplexityPruning();
      var token = cancellationToken ?? CancellationToken.None;
      var stateScope = InitializeScope(random, problemData, pruning, minimumLeafSize, leafModel, splitter, generateRules, useHoldout, holdoutSize);
      var model = Build(stateScope, results, token);
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Calls <c>model.Update</c> with the training indices of <paramref name="problemData"/>.
    /// The update runs in a fresh scope that holds new regression tree parameters for
    /// <paramref name="leafModel"/>, after <paramref name="leafModel"/> has been initialized on that scope.
    /// </summary>
    public static void UpdateModel(IDecisionTreeModel model, IRegressionProblemData problemData, IRandom random, ILeafModel leafModel, CancellationToken? cancellationToken = null) {
      var token = cancellationToken ?? CancellationToken.None;
      var regressionTreeParameters = new RegressionTreeParameters(leafModel, problemData, random);
      var scope = new Scope();
      scope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParameters));
      leafModel.Initialize(scope);
      model.Update(problemData.TrainingIndices.ToList(), scope, token);
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Creates the scope that holds everything needed to build the model:
    /// <list type="bullet">
    /// <item>the tree parameters</item>
    /// <item>the unbuilt tree or rule-set model</item>
    /// <item>the training and pruning row indices</item>
    /// </list>
    /// The data is reduced to the allowed inputs and the target (columns) and to the training partition (rows).
    /// </summary>
    /// <exception cref="NotSupportedException">A used variable is not double-valued, or a training value is NaN/infinite.</exception>
    private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize) {
      var stateScope = new Scope("RegressionTreeStateScope");

      //reduce RegressionProblemData to AllowedInput & Target column wise and to TrainingSet row wise
      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v)))
        throw new NotSupportedException("Decision tree regression supports only double valued input or output features.");
      var doubles = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (doubles.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("Decision tree regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, doubles);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      // The reduced dataset consists solely of training rows: training partition = all rows, test partition empty.
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //store regression tree parameters
      var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);
      stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

      //initialize tree operators
      pruning.Initialize(stateScope);
      splitter.Initialize(stateScope);
      leafModel.Initialize(stateScope);

      //store unbuilt model
      IItem model;
      if (generateRules) {
        model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
        RegressionRuleSetModel.Initialize(stateScope);
      } else {
        model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
      }
      stateScope.Variables.Add(new Variable(ModelVariableName, model));

      //store training & pruning indices
      IReadOnlyList<int> trainingSet, pruningSet;
      GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
      stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
      stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

      return stateScope;
    }

    /// <summary>
    /// Builds the model stored in <paramref name="stateScope"/> using its training and pruning rows.
    /// There are two degenerate cases:
    /// <list type="bullet">
    /// <item>With no training rows, a constant model predicting 0 is returned.</item>
    /// <item>With fewer training rows than the minimal leaf size, a constant model predicting the
    /// mean target value of the reduced dataset is returned.</item>
    /// </list>
    /// </summary>
    private static IRegressionModel Build(IScope stateScope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[RegressionTreeParameterVariableName].Value;
      var model = (IDecisionTreeModel)stateScope.Variables[ModelVariableName].Value;
      var trainingRows = (IntArray)stateScope.Variables[TrainingSetVariableName].Value;
      var pruningRows = (IntArray)stateScope.Variables[PruningSetVariableName].Value;
      if (trainingRows.Length == 0)
        return new PreconstructedLinearModel(new Dictionary<string, double>(), 0, regressionTreeParams.TargetVariable);
      if (regressionTreeParams.MinLeafSize > trainingRows.Length) {
        var targets = regressionTreeParams.Data.GetDoubleValues(regressionTreeParams.TargetVariable).ToArray();
        return new PreconstructedLinearModel(new Dictionary<string, double>(), targets.Average(), regressionTreeParams.TargetVariable);
      }
      model.Build(trainingRows.ToArray(), pruningRows.ToArray(), stateScope, results, cancellationToken);
      return model;
    }

    /// <summary>
    /// Splits <paramref name="allrows"/> into training and pruning rows.
    /// Without holdout, both sets are <paramref name="allrows"/>. With holdout, the rows are randomly
    /// permuted. The first <c>(int)(holdoutSize * allrows.Count)</c> permuted rows become the pruning set
    /// and the remaining rows the training set, so the two sets are disjoint.
    /// </summary>
    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
      if (!useHoldout) {
        training = allrows;
        pruning = allrows;
        return;
      }
      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
      var cut = (int)(holdoutSize * allrows.Count);
      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
      // Training rows are the rest of the permutation, i.e. disjoint from the held-out pruning rows.
      training = perm.Skip(cut).Select(i => allrows[i]).ToArray();
    }

    /// <summary>
    /// Adds the solution and model-specific analyses to <paramref name="results"/>.
    /// For trees this adds the leaf-depth histogram and node analysis. For rule sets it adds the rules
    /// and the coverage diagram. Both add relative variable frequencies. With <c>ComplexityPruning</c>,
    /// trees also get a pruning chart.
    /// </summary>
    private void AnalyzeSolution(IRegressionSolution solution, ResultCollection results, IRegressionProblemData problemData) {
      results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));

      Dictionary<string, int> frequencies = null;

      var tree = solution.Model as RegressionNodeTreeModel;
      if (tree != null) {
        results.Add(RegressionTreeAnalyzer.CreateLeafDepthHistogram(tree));
        frequencies = RegressionTreeAnalyzer.GetTreeVariableFrequences(tree);
        RegressionTreeAnalyzer.AnalyzeNodes(tree, results, problemData);
      }

      var ruleSet = solution.Model as RegressionRuleSetModel;
      if (ruleSet != null) {
        results.Add(RegressionTreeAnalyzer.CreateRulesResult(ruleSet, problemData, "Rules", true));
        frequencies = RegressionTreeAnalyzer.GetRuleVariableFrequences(ruleSet);
        results.Add(RegressionTreeAnalyzer.CreateCoverageDiagram(ruleSet, problemData));
      }

      //Variable frequencies
      if (frequencies != null) {
        var sum = frequencies.Values.Sum();
        sum = sum == 0 ? 1 : sum; // avoid division by zero when no variable is used
        var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
          ElementNames = frequencies.Select(i => i.Key)
        };
        results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
      }

      var pruning = Pruning as ComplexityPruning;
      if (pruning != null && tree != null)
        RegressionTreeAnalyzer.PruningChart(tree, pruning, results);
    }
    #endregion
  }
}