Free cookie consent management tool by TermsFeed Policy Generator

source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/M5NodeModel.cs @ 15430

Last change on this file since 15430 was 15430, checked in by bwerth, 5 years ago

#2847 first implementation of M5'-regression

File size: 12.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// A node of a binary regression tree. An inner node routes a row to <see cref="Left"/> when the row's
  /// <see cref="SplitAttr"/> value is less than or equal to <see cref="SplitValue"/> and to <see cref="Right"/>
  /// otherwise; a leaf node delegates its predictions to <see cref="NodeModel"/>.
  /// </summary>
  [StorableClass]
  internal class M5NodeModel : RegressionModel {
    #region Properties
    /// <summary>True when this node has no children and predicts with <see cref="NodeModel"/>.</summary>
    [Storable]
    internal bool IsLeaf { get; private set; }
    /// <summary>The regression model built for this node by <c>BuildModel</c>; null until a model has been built.</summary>
    [Storable]
    internal IRegressionModel NodeModel { get; private set; }
    /// <summary>Name of the variable this node splits on.</summary>
    [Storable]
    internal string SplitAttr { get; private set; }
    /// <summary>Threshold of the split; rows with a value less than or equal to it go to <see cref="Left"/>.</summary>
    [Storable]
    internal double SplitValue { get; private set; }
    /// <summary>Subtree for rows whose split value is less than or equal to <see cref="SplitValue"/>.</summary>
    [Storable]
    internal M5NodeModel Left { get; private set; }
    /// <summary>Subtree for all rows that do not go to <see cref="Left"/>.</summary>
    [Storable]
    internal M5NodeModel Right { get; private set; }
    /// <summary>The node that created this node during <see cref="Split"/>; cleared by <see cref="ToRuleNode"/>.</summary>
    [Storable]
    internal M5NodeModel Parent { get; set; }
    /// <summary>Number of rows that were passed to the last call of <see cref="Split"/>.</summary>
    [Storable]
    internal int NumSamples { get; private set; }
    /// <summary>
    /// Parameter count of the subtree as maintained by <see cref="Prune"/>: <see cref="NodeModelParams"/> for a leaf,
    /// otherwise the sum of both children's counts plus one.
    /// </summary>
    [Storable]
    internal int NumParam { get; set; }
    /// <summary>Parameter count reported by the leaf type when <see cref="NodeModel"/> was built.</summary>
    [Storable]
    internal int NodeModelParams { get; set; }
    /// <summary>The allowed input variables captured by the last call of <see cref="Split"/>.</summary>
    [Storable]
    private IReadOnlyList<string> Variables { get; set; }

    //a node becomes a leaf once the standard deviation of its target values
    //falls below this fraction of the global standard deviation
    private const double DevFraction = 0.05;
    #endregion

    #region HLConstructors
    [StorableConstructor]
    protected M5NodeModel(bool deserializing) : base(deserializing) { }
    /// <summary>Cloning constructor; copies all stored state and deep-clones the referenced nodes and model.</summary>
    protected M5NodeModel(M5NodeModel original, Cloner cloner) : base(original, cloner) {
      IsLeaf = original.IsLeaf;
      NodeModel = cloner.Clone(original.NodeModel);
      SplitValue = original.SplitValue;
      SplitAttr = original.SplitAttr;
      Left = cloner.Clone(original.Left);
      Right = cloner.Clone(original.Right);
      Parent = cloner.Clone(original.Parent);
      NumParam = original.NumParam;
      //NodeModelParams is stored state as well and must survive cloning
      NodeModelParams = original.NodeModelParams;
      NumSamples = original.NumSamples;
      Variables = original.Variables != null ? original.Variables.ToList() : null;
    }
    protected M5NodeModel(string targetAttr) : base(targetAttr) { }
    protected M5NodeModel(M5NodeModel parent) : base(parent.TargetVariable) {
      Parent = parent;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new M5NodeModel(this, cloner);
    }
    /// <summary>
    /// Creates a root node for the given target variable. A confidence-capable node is created
    /// when the configured leaf type produces <see cref="IConfidenceRegressionModel"/>s.
    /// </summary>
    public static M5NodeModel CreateNode(string targetAttr, M5CreationParameters m5CreationParams) {
      return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(targetAttr) : new M5NodeModel(targetAttr);
    }
    /// <summary>Creates a child node of <paramref name="parent"/>, using the same type selection as the root factory.</summary>
    private static M5NodeModel CreateNode(M5NodeModel parent, M5CreationParameters m5CreationParams) {
      return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(parent) : new M5NodeModel(parent);
    }
    #endregion

    #region RegressionModel
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return Variables; }
    }
    /// <summary>
    /// Returns the estimated values for the given rows. Inner nodes route every row individually through the tree,
    /// leaf nodes forward the whole request to <see cref="NodeModel"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">Thrown when a leaf has no model.</exception>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      if (!IsLeaf) return rows.Select(row => GetEstimatedValue(dataset, row));
      if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
      return NodeModel.GetEstimatedValues(dataset, rows);
    }
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, problemData);
    }
    #endregion

    /// <summary>
    /// Recursively grows the subtree below this node from the given rows. Any existing children and node model are discarded.
    /// The node becomes a leaf when the standard deviation of its target values is below
    /// <c>DevFraction * globalStdDev</c>, when no usable split is found, or when one side of the best split
    /// would contain fewer rows than the configured minimum leaf size.
    /// </summary>
    internal void Split(IReadOnlyList<int> rows, M5CreationParameters m5CreationParams, double globalStdDev) {
      Variables = m5CreationParams.AllowedInputVariables.ToArray();
      NumSamples = rows.Count;
      Right = null;
      Left = null;
      NodeModel = null;
      IsLeaf = m5CreationParams.Data.GetDoubleValues(TargetVariable, rows).StandardDeviation() < globalStdDev * DevFraction;
      if (IsLeaf) return;

      var bestSplit = new SplitInfo();
      var currentSplit = new SplitInfo();

      //the target values do not depend on the candidate attribute => read them only once
      var targetValues = m5CreationParams.Data.GetDoubleValues(TargetVariable, rows).ToArray();

      //find best Attribute for the Split
      foreach (var attr in m5CreationParams.AllowedInputVariables) {
        var sortedData = m5CreationParams.Data.GetDoubleValues(attr, rows).Zip(targetValues, Tuple.Create).OrderBy(x => x.Item1).ToArray();
        currentSplit.AttributeSplit(attr, sortedData.Select(x => x.Item1).ToArray(), sortedData.Select(x => x.Item2).ToArray(), m5CreationParams);
        //only a strictly better (not merely almost equal) impurity replaces the current best split
        if (!currentSplit.MaxImpurity.IsAlmost(bestSplit.MaxImpurity) && currentSplit.MaxImpurity > bestSplit.MaxImpurity)
          bestSplit = (SplitInfo) currentSplit.Clone();
      }

      //if no suitable split exists => leafNode
      if (bestSplit.SplitAttr == null || bestSplit.Position < 1 || bestSplit.Position > rows.Count - 1) {
        IsLeaf = true;
        return;
      }

      SplitAttr = bestSplit.SplitAttr;
      SplitValue = bestSplit.SplitValue;

      //split Dataset
      IReadOnlyList<int> leftRows, rightRows;
      SplitRows(rows, m5CreationParams.Data, SplitAttr, SplitValue, out leftRows, out rightRows);

      //a child that would be too small turns this node into a leaf and discards the split
      if (leftRows.Count < m5CreationParams.MinLeafSize || rightRows.Count < m5CreationParams.MinLeafSize) {
        IsLeaf = true;
        SplitAttr = null;
        SplitValue = double.NaN;
        return;
      }

      //create subtrees
      Left = CreateNode(this, m5CreationParams);
      Left.Split(leftRows, m5CreationParams, globalStdDev);
      Right = CreateNode(this, m5CreationParams);
      Right.Split(rightRows, m5CreationParams, globalStdDev);
    }

    /// <summary>
    /// Prunes the subtree bottom-up. Leaves get a model built with the pruning leaf type. An inner node is only
    /// considered for pruning when at least one of its children is a leaf after pruning; it is converted into a leaf
    /// when the configured pruning type decides so.
    /// </summary>
    /// <returns>True when this node is a leaf after the call, false when it remains an inner node.</returns>
    internal bool Prune(IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows, M5CreationParameters m5CreationParams, CancellationToken cancellation, double globalStdDev) {
      if (IsLeaf) {
        BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
        NumParam = NodeModelParams;
        return true;
      }
      //split training & holdout data
      IReadOnlyList<int> leftTest, rightTest;
      SplitRows(testRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTest, out rightTest);
      IReadOnlyList<int> leftTraining, rightTraining;
      SplitRows(trainingRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTraining, out rightTraining);

      //prune children first
      var lpruned = Left.Prune(leftTraining, leftTest, m5CreationParams, cancellation, globalStdDev);
      var rpruned = Right.Prune(rightTraining, rightTest, m5CreationParams, cancellation, globalStdDev);
      NumParam = Left.NumParam + Right.NumParam + 1;

      //TODO check if this reduces quality. It reduces training effort (considerably for some pruningTypes)
      if (!lpruned && !rpruned) return false;

      BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);

      //check if children will be pruned
      if (!((PruningBase) m5CreationParams.Pruningtype).Prune(this, m5CreationParams, testRows, globalStdDev)) return false;

      //convert to leafNode; all leaves below this node are replaced by this single leaf
      ((IntValue) m5CreationParams.Results[M5RuleModel.NoCurrentLeafesResultName].Value).Value -= EnumerateNodes().Count(x => x.IsLeaf) - 1;
      IsLeaf = true;
      Right = null;
      Left = null;
      NumParam = NodeModelParams;
      return true;
    }

    /// <summary>
    /// Distributes the rows over the tree according to the stored splits and builds a model
    /// of the given leaf type in every leaf.
    /// </summary>
    internal void InstallModels(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
      if (!IsLeaf) {
        IReadOnlyList<int> leftRows, rightRows;
        SplitRows(rows, data, SplitAttr, SplitValue, out leftRows, out rightRows);
        Left.InstallModels(leftRows, random, data, leafType, cancellation);
        Right.InstallModels(rightRows, random, data, leafType, cancellation);
        return;
      }
      BuildModel(rows, data, random, leafType, cancellation);
    }

    /// <summary>Lazily enumerates this node and all of its descendants in breadth-first order.</summary>
    internal IEnumerable<M5NodeModel> EnumerateNodes() {
      var queue = new Queue<M5NodeModel>();
      queue.Enqueue(this);
      while (queue.Count != 0) {
        var cur = queue.Dequeue();
        yield return cur;
        if (cur.Left != null) queue.Enqueue(cur.Left);
        if (cur.Right != null) queue.Enqueue(cur.Right);
      }
    }

    /// <summary>Detaches this node from its parent.</summary>
    internal void ToRuleNode() {
      Parent = null;
    }

    #region Helpers
    /// <summary>Routes a single row down the tree and returns the estimate of the leaf it ends up in.</summary>
    private double GetEstimatedValue(IDataset dataset, int row) {
      if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
      if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
      return NodeModel.GetEstimatedValues(dataset, new[] {row}).First();
    }

    /// <summary>
    /// Builds <see cref="NodeModel"/> with the given leaf type on a reduced dataset that contains only the given rows
    /// and the variables used for prediction plus the target variable. All rows of the reduced dataset are used for training.
    /// </summary>
    private void BuildModel(IReadOnlyList<int> rows, IDataset data, IRandom random, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
      //the same variable list provides the names and the order of the value columns
      var variables = VariablesUsedForPrediction.Concat(new[] {TargetVariable}).ToArray();
      var reducedData = new Dataset(variables, variables.Select(x => data.GetDoubleValues(x, rows).ToList()));
      var pd = new RegressionProblemData(reducedData, VariablesUsedForPrediction, TargetVariable);
      //the training partition covers the whole reduced dataset, the test partition is empty
      pd.TrainingPartition.Start = 0;
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;

      int noparams;
      NodeModel = leafType.BuildModel(pd, random, cancellation, out noparams);
      NodeModelParams = noparams;
      cancellation.ThrowIfCancellationRequested();
    }

    /// <summary>
    /// Partitions the rows in a single pass: rows whose value of <paramref name="splitAttr"/> is less than or equal to
    /// <paramref name="splitValue"/> go to <paramref name="leftRows"/>, all others to <paramref name="rightRows"/>.
    /// The relative order of the rows is preserved in both results.
    /// </summary>
    private static void SplitRows(IReadOnlyList<int> rows, IDataset data, string splitAttr, double splitValue, out IReadOnlyList<int> leftRows, out IReadOnlyList<int> rightRows) {
      var left = new List<int>();
      var right = new List<int>();
      var index = 0;
      //the i-th value is paired with the i-th row, as the previous Zip-based code did
      foreach (var value in data.GetDoubleValues(splitAttr, rows)) {
        (value <= splitValue ? left : right).Add(rows[index]);
        index++;
      }
      leftRows = left;
      rightRows = right;
    }
    #endregion

    /// <summary>
    /// Tree node whose leaf models are <see cref="IConfidenceRegressionModel"/>s and which therefore
    /// can also provide estimated variances.
    /// </summary>
    [StorableClass]
    private sealed class ConfidenceM5NodeModel : M5NodeModel, IConfidenceRegressionModel {
      #region HLConstructors
      [StorableConstructor]
      private ConfidenceM5NodeModel(bool deserializing) : base(deserializing) { }
      private ConfidenceM5NodeModel(ConfidenceM5NodeModel original, Cloner cloner) : base(original, cloner) { }
      public ConfidenceM5NodeModel(string targetAttr) : base(targetAttr) { }
      public ConfidenceM5NodeModel(M5NodeModel parent) : base(parent) { }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new ConfidenceM5NodeModel(this, cloner);
      }
      #endregion

      /// <summary>
      /// Returns the estimated variances for the given rows. Inner nodes route every row individually through the tree,
      /// leaf nodes forward the whole request to their model.
      /// </summary>
      /// <exception cref="NotSupportedException">Thrown when a leaf has no model.</exception>
      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
        if (!IsLeaf) return rows.Select(row => GetEstimatedVariance(dataset, row));
        return GetConfidenceLeafModel().GetEstimatedVariances(dataset, rows);
      }

      /// <summary>Routes a single row down the tree and returns the variance estimate of the leaf it ends up in.</summary>
      private double GetEstimatedVariance(IDataset dataset, int row) {
        if (!IsLeaf)
          return ((IConfidenceRegressionModel) (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right)).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
        return GetConfidenceLeafModel().GetEstimatedVariances(dataset, new[] {row}).First();
      }

      /// <summary>Returns the leaf model, failing the same way the value estimation does when no model has been built.</summary>
      private IConfidenceRegressionModel GetConfidenceLeafModel() {
        if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
        return (IConfidenceRegressionModel) NodeModel;
      }

      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
        return new ConfidenceRegressionSolution(this, problemData);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.