Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/LeafTypes/LeafBase.cs

Last change on this file was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 7.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using System.Threading;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Parameters;
29using HeuristicLab.Problems.DataAnalysis;
30using HEAL.Attic;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Base class for the leaf model types of regression trees.
  /// </summary>
  /// <remarks>
  /// Derived classes implement the abstract <c>Build(IRegressionProblemData, IRandom, CancellationToken, out int)</c>
  /// overload that fits a model on a problem data set. This class:
  /// <list type="bullet">
  /// <item>restricts the data to the rows of a single leaf;</item>
  /// <item>optionally wraps the fitted model in logistic dampening;</item>
  /// <item>drives the construction of the models of all leaves of a tree, keeping its progress
  /// in a <see cref="LeafBuildingState"/> so that the work can be paused and resumed.</item>
  /// </list>
  /// </remarks>
  [StorableType("F3A9CCD4-975F-4F55-BE24-3A3E932591F6")]
  public abstract class LeafBase : ParameterizedNamedItem, ILeafModel {
    public const string LeafBuildingStateVariableName = "LeafBuildingState";
    public const string UseDampeningParameterName = "UseDampening";
    public const string DampeningParameterName = "DampeningStrength";

    // Values stored in LeafBuildingState.Code (the numeric values are persisted, keep them stable).
    private const int StateNothingDone = 0;     // leaves have not been collected yet
    private const int StateBuildingModels = 1;  // leaves collected, leaf models are being built

    /// <summary>Parameter holding the dampening strength (see <see cref="Dampening"/>).</summary>
    public IFixedValueParameter<DoubleValue> DampeningParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[DampeningParameterName]; }
    }
    /// <summary>Parameter holding the dampening switch (see <see cref="UseDampening"/>).</summary>
    public IFixedValueParameter<BoolValue> UseDampeningParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseDampeningParameterName]; }
    }

    /// <summary>Whether models built by <see cref="BuildModel"/> are wrapped in logistic dampening.</summary>
    public bool UseDampening {
      get { return UseDampeningParameter.Value.Value; }
      set { UseDampeningParameter.Value.Value = value; }
    }
    /// <summary>
    /// Strength of the logistic dampening. Dampening is only applied when
    /// <see cref="UseDampening"/> is set and this value is strictly positive.
    /// </summary>
    public double Dampening {
      get { return DampeningParameter.Value.Value; }
      set { DampeningParameter.Value.Value = value; }
    }

    #region Constructors & Cloning
    [StorableConstructor]
    protected LeafBase(StorableConstructorFlag _) : base(_) { }
    protected LeafBase(LeafBase original, Cloner cloner) : base(original, cloner) { }
    protected LeafBase() {
      Parameters.Add(new FixedValueParameter<BoolValue>(UseDampeningParameterName, "Whether logistic dampening should be used to prevent extreme extrapolation (default=false)", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(DampeningParameterName, "Determines the strength of logistic dampening. Must be > 0.0. Larger numbers lead to more conservative predictions. (default=1.5)", new DoubleValue(1.5)));
    }
    #endregion

    #region IModelType
    public abstract bool ProvidesConfidence { get; }
    public abstract int MinLeafSize(IRegressionProblemData pd);

    /// <summary>
    /// Adds a fresh <see cref="LeafBuildingState"/> to <paramref name="states"/> under
    /// <see cref="LeafBuildingStateVariableName"/>.
    /// </summary>
    /// <param name="states">Scope that later calls to <see cref="Build(RegressionNodeTreeModel, IReadOnlyList{int}, IScope, CancellationToken)"/> read the state from.</param>
    public void Initialize(IScope states) {
      states.Variables.Add(new Variable(LeafBuildingStateVariableName, new LeafBuildingState()));
    }

    /// <summary>
    /// Builds and assigns a leaf model for every leaf of <paramref name="tree"/>.
    /// </summary>
    /// <remarks>
    /// Progress is kept in the <see cref="LeafBuildingState"/> stored in <paramref name="stateScope"/>.
    /// After an interruption, calling this method again continues with the leaves that have no
    /// model yet.
    /// </remarks>
    /// <param name="tree">Tree whose leaves receive models.</param>
    /// <param name="trainingRows">Training rows reaching the root of the tree.</param>
    /// <param name="stateScope">Scope holding the regression tree parameters and the leaf building state.</param>
    /// <param name="cancellationToken">Token observed while building each leaf model.</param>
    public void Build(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IScope stateScope, CancellationToken cancellationToken) {
      var parameters = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
      var state = (LeafBuildingState)stateScope.Variables[LeafBuildingStateVariableName].Value;

      if (state.Code == StateNothingDone) {
        state.FillLeafs(tree, trainingRows, parameters.Data);
        state.Code = StateBuildingModels;
      }
      while (state.nodeQueue.Count != 0) {
        // Peek instead of Dequeue: if BuildModel throws (e.g. because of cancellation), the leaf
        // stays queued and is built again when this method is resumed.
        var node = state.nodeQueue.Peek();
        var rows = state.trainingRowsQueue.Peek();
        int numP;
        node.SetLeafModel(BuildModel(rows, parameters, cancellationToken, out numP));
        state.nodeQueue.Dequeue();
        state.trainingRowsQueue.Dequeue();
      }
    }

    /// <summary>
    /// Fits a leaf model on the given subset of rows and optionally applies logistic dampening.
    /// </summary>
    /// <param name="rows">Rows of <c>parameters.Data</c> that belong to the leaf.</param>
    /// <param name="parameters">Regression tree parameters (data, inputs, target, random).</param>
    /// <param name="cancellation">Token checked before and after fitting.</param>
    /// <param name="numberOfParameters">Number of parameters reported by the derived <c>Build</c>.</param>
    /// <returns>The (possibly dampened) leaf model.</returns>
    /// <exception cref="System.OperationCanceledException">When <paramref name="cancellation"/> is cancelled.</exception>
    public IRegressionModel BuildModel(IReadOnlyList<int> rows, RegressionTreeParameters parameters, CancellationToken cancellation, out int numberOfParameters) {
      // Bail out before the (potentially expensive) data reduction and fitting. The check after
      // fitting below is kept so that a cancellation requested while fitting is still honoured.
      cancellation.ThrowIfCancellationRequested();

      var inputVariables = parameters.AllowedInputVariables.ToArray();
      var reducedData = RegressionTreeUtilities.ReduceDataset(parameters.Data, rows, inputVariables, parameters.TargetVariable);
      var pd = new RegressionProblemData(reducedData, inputVariables, parameters.TargetVariable);
      // Use every row of the reduced data set for training; the test partition is empty.
      pd.TrainingPartition.Start = 0;
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;

      int numP;
      var model = Build(pd, parameters.Random, cancellation, out numP);
      if (UseDampening && Dampening > 0.0) {
        model = DampenedModel.DampenModel(model, pd, Dampening);
      }

      numberOfParameters = numP;
      cancellation.ThrowIfCancellationRequested();
      return model;
    }

    /// <summary>Fits the concrete leaf model on <paramref name="pd"/>'s training partition.</summary>
    public abstract IRegressionModel Build(IRegressionProblemData pd, IRandom random, CancellationToken cancellationToken, out int numberOfParameters);
    #endregion

    /// <summary>
    /// Persistent progress of <see cref="LeafBase.Build(RegressionNodeTreeModel, IReadOnlyList{int}, IScope, CancellationToken)"/>:
    /// the leaves that still need a model, each paired with its training rows.
    /// </summary>
    [StorableType("495243C0-6C15-4328-B30D-FFBFA0F54DCB")]
    public class LeafBuildingState : Item {
      // NOTE(review): the names of storable members are likely part of the persisted format;
      // do not rename them without a migration.
      [Storable]
      private RegressionNodeModel[] storableNodeQueue { get { return nodeQueue.ToArray(); } set { nodeQueue = new Queue<RegressionNodeModel>(value); } }
      /// <summary>Leaves still awaiting a model; parallel to <see cref="trainingRowsQueue"/>.</summary>
      public Queue<RegressionNodeModel> nodeQueue;
      [Storable]
      private IReadOnlyList<int>[] storabletrainingRowsQueue { get { return trainingRowsQueue.ToArray(); } set { trainingRowsQueue = new Queue<IReadOnlyList<int>>(value); } }
      /// <summary>Training rows of each queued leaf, in the same order as <see cref="nodeQueue"/>.</summary>
      public Queue<IReadOnlyList<int>> trainingRowsQueue;

      // Current action (used for pausing):
      // 0 ... nothing has been done;
      // 1 ... building models.
      [Storable]
      public int Code = StateNothingDone;

      #region HLConstructors & Cloning
      [StorableConstructor]
      protected LeafBuildingState(StorableConstructorFlag _) : base(_) { }
      protected LeafBuildingState(LeafBuildingState original, Cloner cloner) : base(original, cloner) {
        nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
        // Row lists are copied so the clone does not share mutable collections with the original.
        trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
        Code = original.Code;
      }
      public LeafBuildingState() {
        nodeQueue = new Queue<RegressionNodeModel>();
        trainingRowsQueue = new Queue<IReadOnlyList<int>>();
      }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new LeafBuildingState(this, cloner);
      }
      #endregion

      /// <summary>
      /// Replaces the queued work with every leaf of <paramref name="tree"/> and its training rows.
      /// </summary>
      /// <remarks>
      /// The tree is walked breadth-first from the root. At each inner node the rows are
      /// partitioned with <c>RegressionTreeUtilities.SplitRows</c> on the node's split attribute
      /// and value.
      /// </remarks>
      /// <param name="tree">Tree whose leaves are collected.</param>
      /// <param name="trainingRows">Rows reaching the root.</param>
      /// <param name="data">Data set the rows index into.</param>
      public void FillLeafs(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IDataset data) {
        var helperQueue = new Queue<RegressionNodeModel>();
        var trainingHelperQueue = new Queue<IReadOnlyList<int>>();
        nodeQueue.Clear();
        trainingRowsQueue.Clear();

        helperQueue.Enqueue(tree.Root);
        trainingHelperQueue.Enqueue(trainingRows);

        while (helperQueue.Count != 0) {
          var n = helperQueue.Dequeue();
          var t = trainingHelperQueue.Dequeue();
          if (n.IsLeaf) {
            nodeQueue.Enqueue(n);
            trainingRowsQueue.Enqueue(t);
            continue;
          }

          IReadOnlyList<int> leftTraining, rightTraining;
          RegressionTreeUtilities.SplitRows(t, data, n.SplitAttribute, n.SplitValue, out leftTraining, out rightTraining);

          helperQueue.Enqueue(n.Left);
          helperQueue.Enqueue(n.Right);
          trainingHelperQueue.Enqueue(leftTraining);
          trainingHelperQueue.Enqueue(rightTraining);
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.