[16847] | 1 | #region License Information
|
---|
| 2 | /* HeuristicLab
|
---|
[17180] | 3 | * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
[16847] | 4 | *
|
---|
| 5 | * This file is part of HeuristicLab.
|
---|
| 6 | *
|
---|
| 7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
| 8 | * it under the terms of the GNU General Public License as published by
|
---|
| 9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
| 10 | * (at your option) any later version.
|
---|
| 11 | *
|
---|
| 12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
| 13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
| 14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
| 15 | * GNU General Public License for more details.
|
---|
| 16 | *
|
---|
| 17 | * You should have received a copy of the GNU General Public License
|
---|
| 18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
| 19 | */
|
---|
| 20 | #endregion
|
---|
| 21 |
|
---|
| 22 | using System;
|
---|
[15430] | 23 | using System.Collections.Generic;
|
---|
| 24 | using System.Linq;
|
---|
| 25 | using System.Threading;
|
---|
| 26 | using HeuristicLab.Common;
|
---|
| 27 | using HeuristicLab.Core;
|
---|
| 28 | using HeuristicLab.Data;
|
---|
[15614] | 29 | using HeuristicLab.Encodings.PermutationEncoding;
|
---|
[15430] | 30 | using HeuristicLab.Optimization;
|
---|
| 31 | using HeuristicLab.Parameters;
|
---|
| 32 | using HeuristicLab.PluginInfrastructure;
|
---|
| 33 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 34 | using HeuristicLab.Random;
|
---|
[16847] | 35 | using HEAL.Attic;
|
---|
[15430] | 36 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Learns a regression tree or, when <see cref="GenerateRules"/> is set, a regression rule set
  /// from the training partition of an <see cref="IRegressionProblem"/>.
  /// The working state (parameters, unbuilt model, training/pruning rows) lives in a storable
  /// state scope that is created in <see cref="Initialize"/> and consumed in <see cref="Run"/>.
  /// </summary>
  [StorableType("FC8D8E5A-D16D-41BB-91CF-B2B35D17ADD7")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("Decision Tree Regression (DT)", "A regression tree / rule set learner")]
  public sealed class DecisionTreeRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    public override bool SupportsPause {
      get { return true; }
    }

    // names of the variables stored in the state scope
    public const string RegressionTreeParameterVariableName = "RegressionTreeParameters";
    public const string ModelVariableName = "Model";
    public const string PruningSetVariableName = "PruningSet";
    public const string TrainingSetVariableName = "TrainingSet";

    #region Parameter names
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string HoldoutSizeParameterName = "HoldoutSize";
    private const string SplitterParameterName = "Splitter";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string LeafModelParameterName = "LeafModel";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string UseHoldoutParameterName = "UseHoldout";
    #endregion

    #region Parameter properties
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
    }
    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
    }
    public IConstrainedValueParameter<ISplitter> SplitterParameter {
      get { return (IConstrainedValueParameter<ISplitter>)Parameters[SplitterParameterName]; }
    }
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
    }
    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
    }
    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    }
    #endregion

    #region Properties
    /// <summary>Whether a rule set (true) or a decision tree (false) is created.</summary>
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
      set { GenerateRulesParameter.Value.Value = value; }
    }
    /// <summary>Fraction of the training rows reserved for pruning when <see cref="UseHoldout"/> is set.</summary>
    public double HoldoutSize {
      get { return HoldoutSizeParameter.Value.Value; }
      set { HoldoutSizeParameter.Value.Value = value; }
    }
    public ISplitter Splitter {
      get { return SplitterParameter.Value; }
      // no setter because this is a constrained parameter
    }
    /// <summary>The minimal number of samples in a leaf node.</summary>
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
      set { MinimalNodeSizeParameter.Value.Value = value; }
    }
    public ILeafModel LeafModel {
      get { return LeafModelParameter.Value; }
    }
    public IPruning Pruning {
      get { return PruningTypeParameter.Value; }
    }
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>Whether a separate holdout set is used for pruning (otherwise splitting and pruning use the same rows).</summary>
    public bool UseHoldout {
      get { return UseHoldoutParameter.Value.Value; }
      set { UseHoldoutParameter.Value.Value = value; }
    }
    #endregion

    #region State
    // created in Initialize and consumed in Run; storable and cloned along with the algorithm
    [Storable]
    private IScope stateScope;
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private DecisionTreeRegression(StorableConstructorFlag _) : base(_) { }
    private DecisionTreeRegression(DecisionTreeRegression original, Cloner cloner) : base(original, cloner) {
      stateScope = cloner.Clone(stateScope);
    }
    public DecisionTreeRegression() {
      // all discovered plugin implementations are offered as choices for the constrained parameters
      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
      var splitterSet = new ItemSet<ISplitter>(ApplicationManager.Manager.GetInstances<ISplitter>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created (default=false)", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning (default=20%).", new PercentValue(0.2)));
      Parameters.Add(new ConstrainedValueParameter<ISplitter>(SplitterParameterName, "The type of split function used to create node splits (default='Splitter').", splitterSet, splitterSet.OfType<Splitter>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node (default=1).", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes (default='LinearLeaf').", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used (default='ComplexityPruning').", pruningSet, pruningSet.OfType<ComplexityPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data (default=false).", new BoolValue(false)));
      Problem = new RegressionProblem();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new DecisionTreeRegression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Seeds a new <see cref="MersenneTwister"/> (drawing a fresh seed first if <see cref="SetSeedRandomly"/> is set),
    /// creates the state scope, stores this algorithm in it under "Algorithm" and publishes it as the "StateScope" result.
    /// </summary>
    protected override void Initialize(CancellationToken cancellationToken) {
      base.Initialize(cancellationToken);
      var random = new MersenneTwister();
      if (SetSeedRandomly) Seed = RandomSeedGenerator.GetSeed();
      random.Reset(Seed);
      stateScope = InitializeScope(random, Problem.ProblemData, Pruning, MinimalNodeSize, LeafModel, Splitter, GenerateRules, UseHoldout, HoldoutSize);
      stateScope.Variables.Add(new Variable("Algorithm", this));
      Results.AddOrUpdateResult("StateScope", stateScope);
    }

    /// <summary>Builds the model held in the state scope and adds solution analyses to <see cref="Results"/>.</summary>
    protected override void Run(CancellationToken cancellationToken) {
      var model = Build(stateScope, Results, cancellationToken);
      AnalyzeSolution(model.CreateRegressionSolution(Problem.ProblemData), Results, Problem.ProblemData);
    }

    #region Static Interface
    /// <summary>
    /// Builds a decision tree (or rule set) for <paramref name="problemData"/> without creating an algorithm instance.
    /// </summary>
    /// <param name="problemData">Problem data whose training partition is used.</param>
    /// <param name="random">Random number generator (used e.g. for drawing the holdout set).</param>
    /// <param name="leafModel">Leaf model; a new <see cref="LinearLeaf"/> when null.</param>
    /// <param name="splitter">Split function; a new <see cref="Splitter"/> when null.</param>
    /// <param name="pruning">Pruning strategy; a new <see cref="ComplexityPruning"/> when null.</param>
    /// <param name="useHoldout">Whether a separate holdout set is used for pruning.</param>
    /// <param name="holdoutSize">Fraction of training rows reserved for pruning when <paramref name="useHoldout"/> is set.</param>
    /// <param name="minimumLeafSize">Minimal number of samples in a leaf node.</param>
    /// <param name="generateRules">True for a rule set, false for a tree.</param>
    /// <param name="results">Optional result collection handed to the model building step.</param>
    /// <param name="cancellationToken">Cancellation token; <see cref="CancellationToken.None"/> when null.</param>
    /// <returns>A regression solution wrapping the built model.</returns>
    public static IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, IRandom random, ILeafModel leafModel = null, ISplitter splitter = null, IPruning pruning = null,
      bool useHoldout = false, double holdoutSize = 0.2, int minimumLeafSize = 1, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
      if (leafModel == null) leafModel = new LinearLeaf();
      if (splitter == null) splitter = new Splitter();
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      if (pruning == null) pruning = new ComplexityPruning();

      var stateScope = InitializeScope(random, problemData, pruning, minimumLeafSize, leafModel, splitter, generateRules, useHoldout, holdoutSize);
      var model = Build(stateScope, results, cancellationToken.Value);
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Calls <c>Update</c> on an existing model with the training indices of <paramref name="problemData"/>,
    /// using a fresh scope that holds <see cref="RegressionTreeParameters"/> built from
    /// <paramref name="leafModel"/>, <paramref name="problemData"/> and <paramref name="random"/>.
    /// NOTE(review): presumably refits the leaf models on new data — confirm against IDecisionTreeModel.Update.
    /// </summary>
    public static void UpdateModel(IDecisionTreeModel model, IRegressionProblemData problemData, IRandom random, ILeafModel leafModel, CancellationToken? cancellationToken = null) {
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      var regressionTreeParameters = new RegressionTreeParameters(leafModel, problemData, random);
      var scope = new Scope();
      scope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParameters));
      leafModel.Initialize(scope);
      model.Update(problemData.TrainingIndices.ToList(), scope, cancellationToken.Value);
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Creates the state scope holding the tree parameters, the unbuilt model and the training/pruning row sets.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// If an allowed input or the target is not a double variable, or if the training data contains NaN or infinity.
    /// </exception>
    private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize) {
      var stateScope = new Scope("RegressionTreeStateScope");

      //reduce RegressionProblemData to AllowedInput & Target column wise and to TrainingSet row wise
      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("Decision tree regression supports only double valued input or output features.");
      var doubles = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (doubles.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("Decision tree regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, doubles);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      // the reduced dataset consists of training rows only: the whole of it is the training partition, the test partition is empty
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //store regression tree parameters
      var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);
      stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

      //initialize tree operators
      pruning.Initialize(stateScope);
      splitter.Initialize(stateScope);
      leafModel.Initialize(stateScope);

      //store unbuilt model
      IItem model;
      if (generateRules) {
        model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
        RegressionRuleSetModel.Initialize(stateScope);
      }
      else {
        model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
      }
      stateScope.Variables.Add(new Variable(ModelVariableName, model));

      //store training & pruning indices
      IReadOnlyList<int> trainingSet, pruningSet;
      GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
      stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
      stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

      return stateScope;
    }

    /// <summary>
    /// Builds the model stored in <paramref name="stateScope"/> on its training/pruning rows.
    /// Returns a constant model instead when there are no training rows (constant 0) or fewer
    /// training rows than the minimal leaf size (mean of the target over the stored dataset).
    /// </summary>
    private static IRegressionModel Build(IScope stateScope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[RegressionTreeParameterVariableName].Value;
      var model = (IDecisionTreeModel)stateScope.Variables[ModelVariableName].Value;
      var trainingRows = (IntArray)stateScope.Variables[TrainingSetVariableName].Value;
      var pruningRows = (IntArray)stateScope.Variables[PruningSetVariableName].Value;
      if (1 > trainingRows.Length)
        return new PreconstructedLinearModel(new Dictionary<string, double>(), 0, regressionTreeParams.TargetVariable);
      if (regressionTreeParams.MinLeafSize > trainingRows.Length) {
        var targets = regressionTreeParams.Data.GetDoubleValues(regressionTreeParams.TargetVariable).ToArray();
        return new PreconstructedLinearModel(new Dictionary<string, double>(), targets.Average(), regressionTreeParams.TargetVariable);
      }
      model.Build(trainingRows.ToArray(), pruningRows.ToArray(), stateScope, results, cancellationToken);
      return model;
    }

    /// <summary>
    /// Splits <paramref name="allrows"/> into a training set and a pruning set.
    /// Without holdout both sets are <paramref name="allrows"/> itself. With holdout the rows are
    /// permuted using <paramref name="random"/>; the first <c>(int)(holdoutSize * allrows.Count)</c>
    /// permuted rows form the pruning set and the remaining rows form the training set, so the
    /// two sets are disjoint and together contain every row.
    /// </summary>
    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
      if (!useHoldout) {
        training = allrows;
        pruning = allrows;
        return;
      }
      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
      var cut = (int)(holdoutSize * allrows.Count);
      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
      // Skip, not Take: the training rows are the complement of the holdout rows
      training = perm.Skip(cut).Select(i => allrows[i]).ToArray();
    }

    /// <summary>
    /// Adds a clone of <paramref name="solution"/> and model-specific analyses (leaf depth histogram,
    /// node analysis, rules, coverage, relative variable frequencies, pruning chart) to <paramref name="results"/>.
    /// </summary>
    private void AnalyzeSolution(IRegressionSolution solution, ResultCollection results, IRegressionProblemData problemData) {
      results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));

      Dictionary<string, int> frequencies = null;

      var tree = solution.Model as RegressionNodeTreeModel;
      if (tree != null) {
        results.Add(RegressionTreeAnalyzer.CreateLeafDepthHistogram(tree));
        frequencies = RegressionTreeAnalyzer.GetTreeVariableFrequences(tree);
        RegressionTreeAnalyzer.AnalyzeNodes(tree, results, problemData);
      }

      var ruleSet = solution.Model as RegressionRuleSetModel;
      if (ruleSet != null) {
        results.Add(RegressionTreeAnalyzer.CreateRulesResult(ruleSet, problemData, "Rules", true));
        frequencies = RegressionTreeAnalyzer.GetRuleVariableFrequences(ruleSet);
        results.Add(RegressionTreeAnalyzer.CreateCoverageDiagram(ruleSet, problemData));
      }

      //Variable frequencies
      if (frequencies != null) {
        var sum = frequencies.Values.Sum();
        // avoid division by zero when no variable is used at all
        sum = sum == 0 ? 1 : sum;
        var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
          ElementNames = frequencies.Select(i => i.Key)
        };
        results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
      }

      var pruning = Pruning as ComplexityPruning;
      if (pruning != null && tree != null)
        RegressionTreeAnalyzer.PruningChart(tree, pruning, results);
    }
    #endregion
  }
}
---|