Changeset 15973 for branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Timestamp: 06/28/18 11:13:37 (6 years ago)
- Location: branches/2522_RefactorPluginInfrastructure
- Files: 17 edited
branches/2522_RefactorPluginInfrastructure
- Property svn:ignore modified: entry ".vs" added (existing entries "protoc.exe" and "obj" unchanged)
- Property svn:mergeinfo changed
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4
- Property svn:mergeinfo set to:
  /stable/HeuristicLab.Algorithms.DataAnalysis/3.4 merged eligible
  /trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 merged eligible
  /branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4 10321-10322
  /branches/Async/HeuristicLab.Algorithms.DataAnalysis/3.4 13329-15286
  /branches/Benchmarking/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 6917-7005
  /branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4 9070-13099
  /branches/CloningRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 4656-4721
  /branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 5471-5808
  /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Algorithms.DataAnalysis/3.4 5815-6180
  /branches/DataAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 4458-4459,4462,4464
  /branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.4 10085-11101
  /branches/GP.Grammar.Editor/HeuristicLab.Algorithms.DataAnalysis/3.4 6284-6795
  /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Algorithms.DataAnalysis/3.4 5060
  /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 11570-12508
  /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Algorithms.DataAnalysis/3.4 11130-12721
  /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4 13819-14091
  /branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.4 8116-8789
  /branches/LogResidualEvaluator/HeuristicLab.Algorithms.DataAnalysis/3.4 10202-10483
  /branches/NET40/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 5138-5162
  /branches/ParallelEngine/HeuristicLab.Algorithms.DataAnalysis/3.4 5175-5192
  /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Algorithms.DataAnalysis/3.4 7773-7810
  /branches/QAPAlgorithms/HeuristicLab.Algorithms.DataAnalysis/3.4 6350-6627
  /branches/Restructure trunk solution/HeuristicLab.Algorithms.DataAnalysis/3.4 6828
  /branches/SpectralKernelForGaussianProcesses/HeuristicLab.Algorithms.DataAnalysis/3.4 10204-10479
  /branches/SuccessProgressAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 5370-5682
  /branches/Trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 6829-6865
  /branches/VNS/HeuristicLab.Algorithms.DataAnalysis/3.4 5594-5752
  /branches/Weighted TSNE/3.4 15451-15531
  /branches/histogram/HeuristicLab.Algorithms.DataAnalysis/3.4 5959-6341
  /branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4 14232-14825
  /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 13402-15674
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
Changes from r13238 to r15973:
- Copyright notice updated to "Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)"; the unused "using System;" directive was removed.
- The Item description now reads "Gradient boosted trees algorithm. Specific implementation of gradient boosting for regression trees. Friedman, J. \"Greedy Function Approximation: A Gradient Boosting Machine\", IMS 1999 Reitz Lecture."
- The class now derives from FixedDataAnalysisAlgorithm<IRegressionProblem> instead of BasicAlgorithm; the ProblemType property and the Problem property override were removed.
- After adding the "Loss (train)" and "Loss (test)" rows to the qualities data table, VisualProperties.StartIndexZero is set to true for both rows.
- When a logistic regression loss is used, the newly created ClassificationProblemData now copies TrainingPartition.Start/End and TestPartition.Start/End from Problem.ProblemData, and classificationModel.SetThresholdsAndClassValues(new double[] { double.NegativeInfinity, 0.0 }, new[] { 0.0, 1.0 }) replaces the previous call to classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices).
- The regression result "Solution" now wraps a GradientBoostedTreesSolution instead of a plain RegressionSolution.
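The threshold/class-value pair set in this changeset turns the real-valued GBT output into a two-class discriminant: scores below 0 map to class 0.0, the rest to class 1.0. A minimal standalone sketch of that mapping (ClassValueFor is a hypothetical helper, not the HeuristicLab API, and treating the boundary value 0.0 as belonging to the upper interval is an assumption):

using System;

// Illustrative sketch only: pick the class whose interval contains the score.
// With thresholds { -infinity, 0.0 } and class values { 0.0, 1.0 }, negative
// scores fall into the first interval (class 0.0), the rest into the second (class 1.0).
static class ThresholdClassifierSketch {
  public static double ClassValueFor(double score, double[] thresholds, double[] classValues) {
    int cls = 0;
    for (int i = 0; i < thresholds.Length; i++)
      if (score >= thresholds[i]) cls = i; // the last threshold not exceeding the score wins
    return classValues[cls];
  }

  public static void Main() {
    var thresholds = new[] { double.NegativeInfinity, 0.0 };
    var classValues = new[] { 0.0, 1.0 };
    Console.WriteLine(ClassValueFor(-0.7, thresholds, classValues)); // prints 0
    Console.WriteLine(ClassValueFor(0.3, thresholds, classValues));  // prints 1
  }
}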
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
Changes from r13157 to r15973:
- Copyright notice updated to 2002-2018.
- The constant start model is now created with the target variable: models.Add(new ConstantModel(f0, problemData.TargetVariable)).
- CreateGbmState now validates the input variables: only double variables are allowed, and a NotSupportedException ("Gradient tree boosting only supports real-valued variables. Unsupported inputs: ...") is thrown if any allowed input variable has a different type.
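The new check in CreateGbmState rejects non-double inputs up front. A minimal sketch of the same idea outside the HeuristicLab API, assuming a plain dictionary of column element types in place of IDataset.VariableHasType<double>() (EnsureRealValued and columnTypes are hypothetical names):

using System;
using System.Collections.Generic;
using System.Linq;

// Sketch: reject any allowed input whose column is not of element type double.
// Assumes every allowed input name is present as a key in columnTypes.
static class InputValidationSketch {
  public static void EnsureRealValued(IReadOnlyDictionary<string, Type> columnTypes,
                                      IEnumerable<string> allowedInputs) {
    var invalid = allowedInputs.Where(name => columnTypes[name] != typeof(double)).ToList();
    if (invalid.Any())
      throw new NotSupportedException(
        "Gradient tree boosting only supports real-valued variables. Unsupported inputs: "
        + string.Join(", ", invalid));
  }
}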
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs
Changes from r13157 to r15973:
- Copyright notice updated to 2002-2018; the Item name was corrected to "Gradient boosted trees model".
- The class (essentially a collection of weighted regression models) now derives from RegressionModel instead of NamedItem; GetEstimatedValues and CreateRegressionSolution become overrides, and a new VariablesUsedForPrediction override returns the distinct, ordered union of the variables used by the contained models.
- The obsolete constructor taking models and weights is now internal and calls base(string.Empty, "Gradient boosted tree model", string.Empty), i.e. it passes an empty target variable to the RegressionModel base class.
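As the summary above notes, the model is essentially a collection of weighted regression models whose estimates are summed per row. A minimal sketch of that prediction rule, with delegate-based stand-in models instead of IRegressionModel (WeightedEnsemble is an illustrative type, not part of HeuristicLab):

using System;
using System.Collections.Generic;

// Sketch of the boosted-ensemble prediction idea: the estimate for each row is
// the weight-scaled sum of the per-stage model outputs.
class WeightedEnsemble {
  private readonly List<Func<IReadOnlyList<double>, double>> models = new List<Func<IReadOnlyList<double>, double>>();
  private readonly List<double> weights = new List<double>();

  public void Add(Func<IReadOnlyList<double>, double> model, double weight) {
    models.Add(model);
    weights.Add(weight);
  }

  public IEnumerable<double> GetEstimatedValues(IEnumerable<IReadOnlyList<double>> rows) {
    foreach (var row in rows) {
      double sum = 0.0;
      for (int i = 0; i < models.Count; i++)
        sum += weights[i] * models[i](row); // accumulate weighted stage predictions
      yield return sum;
    }
  }
}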
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
Changes from r13157 to r15973:
- Copyright notice updated to 2002-2018; using directives for System and System.Linq were added.
- The class now derives from RegressionModel instead of NamedItem.
- The actual model is now held in a readonly Lazy<IGradientBoostedTreesModel> field with an ActualModel property; the manual "if (actualModel == null) actualModel = RecalculateModel();" guards in GetEstimatedValues, Models, and Weights were replaced by accesses to ActualModel.
- A new VariablesUsedForPrediction override returns the distinct, ordered union of the variables used by ActualModel.Models.
- The storable constructor initializes the lazy wrapper with RecalculateModel; the cloning constructor clones the original's model (if present) and wraps it via CreateLazyInitFunc(clonedModel), so that only the cloned model is captured in the closure and RecalculateModel() is used as a fallback when it is null.
- Both public constructors now pass trainingProblemData.TargetVariable to the base constructor; the constructor that wraps an existing model initializes the lazy field with that model.
- GetEstimatedValues and CreateRegressionSolution are now overrides that forward to ActualModel.
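The main change here is the switch from a manually guarded nullable field to Lazy<T>, so the expensive model is rebuilt at most once and only on first access. A simplified sketch of that pattern with hypothetical ModelSurrogate/HeavyModel types (the real surrogate re-runs GBT training with the stored seed and parameters in RecalculateModel):

using System;

// Sketch of the surrogate pattern: the expensive model is wrapped in Lazy<T>
// instead of being guarded by repeated null checks before every use.
sealed class ModelSurrogate {
  private readonly Lazy<HeavyModel> actualModel;
  private HeavyModel ActualModel { get { return actualModel.Value; } }

  // surrogate only: the model is recalculated on demand
  public ModelSurrogate() {
    actualModel = new Lazy<HeavyModel>(Recalculate);
  }

  // wrap an existing model: the lazy wrapper simply returns it
  public ModelSurrogate(HeavyModel model) {
    actualModel = new Lazy<HeavyModel>(() => model);
  }

  public double Predict(double[] row) {
    return ActualModel.Predict(row); // triggers recalculation at most once
  }

  private HeavyModel Recalculate() {
    // placeholder for the expensive rebuild (GBT retraining in the changeset)
    return new HeavyModel();
  }
}

sealed class HeavyModel {
  public double Predict(double[] row) { return 0.0; }
}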
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesSolution.cs
Changes from r13158 to r15973: copyright notice updated to 2002-2018; the unused using directives for System.Collections.Generic and System.Linq were removed.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/IGradientBoostedTreesModel.cs
Changes from r13157 to r15973: copyright notice updated to 2002-2018; the unused using directives (System, System.Linq, HeuristicLab.Common, HeuristicLab.Core, HeuristicLab.Persistence.Default.CompositeSerializers.Storable) were removed, keeping System.Collections.Generic and HeuristicLab.Problems.DataAnalysis.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs
Changes from r12875 to r15973: copyright notice updated to 2002-2018.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs
Changes from r12873 to r15973: copyright notice updated to 2002-2018.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs
Changes from r12875 to r15973: copyright notice updated to 2002-2018; the unused using directive for System.Linq was removed.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/QuantileRegressionLoss.cs
Changes from r13026 to r15973: copyright notice updated to 2002-2018.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs
Changes from r12875 to r15973: copyright notice updated to 2002-2018; the unused using directives for System.Diagnostics and System.Linq were removed.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs
Changes from r12875 to r15973: copyright notice updated to 2002-2018.
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
Changes from r13065 to r15973:
- Copyright notice updated to 2002-2018; the unused using directive for System.Collections was removed; minor whitespace and comment cleanup ("shuffle variable names", "processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached").
- The finished tree is now returned as new RegressionTreeModel(tree.ToArray(), problemData.TargetVariable), i.e. the target variable is passed to the model.
- When a leaf is replaced by an internal split node, the node now also stores the fraction of training samples routed to the left sub-tree: weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1).
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
Changes from r13030 to r15973:
- Copyright notice updated to 2002-2018; using directives for HeuristicLab.Encodings.SymbolicExpressionTreeEncoding, HeuristicLab.Problems.DataAnalysis.Symbolic, and HeuristicLab.Problems.DataAnalysis.Symbolic.Regression were added.
- The class now derives from RegressionModel instead of NamedItem/IRegressionModel; a new VariablesUsedForPrediction override returns the split variables used in the tree; the internal constructor takes the target variable and passes it to the base constructor; GetEstimatedValues and CreateRegressionSolution are now overrides.
- TreeNode gains a WeightLeft property (the fraction of training samples in the left sub-tree, in the range [0..1], used for partial dependence plots); the constructor takes an optional weightLeft parameter (default -1.0); the property setters are now internal; Equals also compares WeightLeft.
- The old storable representation (Tuple<string, double, int, int>[] SerializedTree) is kept as a one-way, deserialization-only property; models loaded in the old format get WeightLeft = -1.0 to indicate that partial dependence cannot be calculated for them. The new storable format persists the node fields as parallel arrays: SerializedTreeVarNames, SerializedTreeValues, SerializedTreeLeftIdx, SerializedTreeRightIdx, and SerializedTreeWeightLeft.
- GetPredictionForRow: if the column for the split variable is not cached or the value is NaN, the prediction is the weighted average of the left and right sub-tree predictions using WeightLeft (an InvalidOperationException is thrown for old models with WeightLeft == -1.0); otherwise the usual threshold comparison decides the branch.
- GetEstimatedValues only caches columns that actually exist in the dataset, so estimations can also be calculated when not all variables used for training are available.
- New methods CreateSymbolicRegressionSolution(IRegressionProblemData) and CreateSymbolicRegressionModel() transform the tree model into a symbolic regression solution/model built from ProgramRootSymbol, StartSymbol, VariableCondition (with IgnoreSlope = true), and Constant symbols, evaluated with the SymbolicDataAnalysisExpressionTreeLinearInterpreter; CreateSymbolicRegressionTreeRecursive maps leaves to ConstantTreeNodes and split nodes to VariableConditionTreeNodes.
- TreeToString now also prints the left weight (WeightLeft) and right weight (1 - WeightLeft) of each split condition.
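The WeightLeft field is what makes partial dependence possible: when the split variable of a node is not available (NaN), the prediction becomes the average of both sub-trees weighted by the fraction of training rows that went left. A simplified sketch of that traversal over a flat node array (Node and Predict are illustrative stand-ins for RegressionTreeModel.TreeNode and GetPredictionForRow):

using System;

// Flat-array regression tree node; VarIndex == -1 marks a leaf.
struct Node {
  public int VarIndex;      // index of the split variable, or -1 for a leaf
  public double Threshold;  // split threshold, or the constant value at a leaf
  public int LeftIdx, RightIdx;
  public double WeightLeft; // fraction of training samples in the left sub-tree
}

static class FlatTreeSketch {
  public static double Predict(Node[] tree, int nodeIdx, double[] row) {
    var node = tree[nodeIdx];
    if (node.VarIndex < 0) return node.Threshold;                        // leaf: constant value
    double x = row[node.VarIndex];
    if (double.IsNaN(x))                                                 // variable not available:
      return node.WeightLeft * Predict(tree, node.LeftIdx, row) +        // average both branches,
             (1.0 - node.WeightLeft) * Predict(tree, node.RightIdx, row); // weighted by the training split
    return x <= node.Threshold
      ? Predict(tree, node.LeftIdx, row)
      : Predict(tree, node.RightIdx, row);
  }
}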