Free cookie consent management tool by TermsFeed Policy Generator

source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/M5NodeModel.cs @ 15614

Last change on this file since 15614 was 15614, checked in by bwerth, 6 years ago

#2847 made changes to M5 according to review comments

File size: 8.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// A node of an M5 regression tree. An inner node routes each row to <see cref="Left"/> or
  /// <see cref="Right"/> by comparing the row's <see cref="SplitAttribute"/> value with
  /// <see cref="SplitValue"/>; a leaf node delegates predictions to its <see cref="Model"/>.
  /// </summary>
  [StorableClass]
  internal class M5NodeModel : RegressionModel {
    #region Properties
    /// <summary>True if this node predicts with <see cref="Model"/> instead of delegating to its children.</summary>
    [Storable]
    internal bool IsLeaf { get; private set; }
    /// <summary>The leaf model; reset by <see cref="Split"/> and assigned by <see cref="BuildLeafModels"/>.</summary>
    [Storable]
    internal IRegressionModel Model { get; set; }
    /// <summary>The variable this node splits on; null if <see cref="Split"/> did not produce children.</summary>
    [Storable]
    internal string SplitAttribute { get; private set; }
    /// <summary>
    /// Split threshold: rows whose <see cref="SplitAttribute"/> value is less than or equal to it are
    /// predicted by <see cref="Left"/>, all others by <see cref="Right"/>. NaN if no split was made.
    /// </summary>
    [Storable]
    internal double SplitValue { get; private set; }
    [Storable]
    internal M5NodeModel Left { get; private set; }
    [Storable]
    internal M5NodeModel Right { get; private set; }
    /// <summary>The node this node was created under; null for a root created via <see cref="CreateNode(string, M5Parameters)"/>.</summary>
    [Storable]
    internal M5NodeModel Parent { get; private set; }
    /// <summary>Number of training rows passed to the last call of <see cref="Split"/>.</summary>
    [Storable]
    internal int NumSamples { get; private set; }
    // Allowed input variables captured in Split; null until Split has run.
    [Storable]
    private IReadOnlyList<string> variables;
    #endregion

    #region HLConstructors
    [StorableConstructor]
    protected M5NodeModel(bool deserializing) : base(deserializing) { }
    protected M5NodeModel(M5NodeModel original, Cloner cloner) : base(original, cloner) {
      IsLeaf = original.IsLeaf;
      Model = cloner.Clone(original.Model);
      SplitValue = original.SplitValue;
      SplitAttribute = original.SplitAttribute;
      Left = cloner.Clone(original.Left);
      Right = cloner.Clone(original.Right);
      Parent = cloner.Clone(original.Parent);
      NumSamples = original.NumSamples;
      variables = original.variables != null ? original.variables.ToList() : null;
    }
    private M5NodeModel(string targetAttr) : base(targetAttr) { }
    private M5NodeModel(M5NodeModel parent) : this(parent.TargetVariable) {
      Parent = parent;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new M5NodeModel(this, cloner);
    }
    /// <summary>
    /// Creates a root node for <paramref name="targetAttr"/>. A node that also implements
    /// <see cref="IConfidenceRegressionModel"/> is returned when the configured leaf model provides confidence.
    /// </summary>
    public static M5NodeModel CreateNode(string targetAttr, M5Parameters m5Params) {
      return m5Params.LeafModel.ProvidesConfidence ? new ConfidenceM5NodeModel(targetAttr) : new M5NodeModel(targetAttr);
    }
    // Creates a child of parent, choosing the node type the same way as the public overload.
    private static M5NodeModel CreateNode(M5NodeModel parent, M5Parameters m5Params) {
      return m5Params.LeafModel.ProvidesConfidence ? new ConfidenceM5NodeModel(parent) : new M5NodeModel(parent);
    }
    #endregion

    #region RegressionModel
    /// <summary>The input variables the tree was grown with; empty if <see cref="Split"/> has not run yet.</summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      // Never hand out null: before Split runs no variables have been captured.
      get { return variables ?? Enumerable.Empty<string>(); }
    }
    /// <summary>
    /// Predicts <paramref name="rows"/>. Inner nodes route every row individually down the tree;
    /// a leaf passes all rows to its model at once.
    /// </summary>
    /// <exception cref="NotSupportedException">A leaf that has to predict has no model assigned.</exception>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      if (!IsLeaf) return rows.Select(row => GetEstimatedValue(dataset, row));
      return GetLeafModel().GetEstimatedValues(dataset, rows);
    }
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, problemData);
    }
    #endregion

    /// <summary>
    /// Recursively grows the subtree rooted at this node from <paramref name="rows"/>. Any previous children,
    /// split and leaf model are discarded first. The node stays a leaf when the splitter finds no split or when
    /// either side of the split would receive fewer than <c>MinLeafSize</c> rows.
    /// </summary>
    internal void Split(IReadOnlyList<int> rows, M5Parameters m5Params) {
      variables = m5Params.AllowedInputVariables.ToArray();
      NumSamples = rows.Count;
      Right = null;
      Left = null;
      Model = null;
      SplitAttribute = null;
      SplitValue = double.NaN;
      string attr;
      double splitValue;
      IsLeaf = !m5Params.Spliter.Split(new RegressionProblemData(M5StaticUtilities.ReduceDataset(m5Params.Data, rows, variables, TargetVariable), variables, TargetVariable), m5Params.MinLeafSize, out attr, out splitValue);
      if (IsLeaf) return;

      //split Dataset
      IReadOnlyList<int> leftRows, rightRows;
      M5StaticUtilities.SplitRows(rows, m5Params.Data, attr, splitValue, out leftRows, out rightRows);

      // The splitter accepted the split, but one side is still too small: keep this node a leaf.
      if (leftRows.Count < m5Params.MinLeafSize || rightRows.Count < m5Params.MinLeafSize) {
        IsLeaf = true;
        return;
      }
      SplitAttribute = attr;
      SplitValue = splitValue;

      //create subtrees
      Left = CreateNode(this, m5Params);
      Left.Split(leftRows, m5Params);
      Right = CreateNode(this, m5Params);
      Right.Split(rightRows, m5Params);
    }

    /// <summary>Turns this node into a leaf by dropping both subtrees (split attribute and value are kept).</summary>
    internal void ToLeaf() {
      IsLeaf = true;
      Right = null;
      Left = null;
    }

    /// <summary>
    /// Distributes <paramref name="rows"/> down the tree using the stored splits and fits a leaf model
    /// with <c>parameters.LeafModel</c> at every leaf.
    /// </summary>
    internal void BuildLeafModels(IReadOnlyList<int> rows, M5Parameters parameters, CancellationToken cancellationToken) {
      if (!IsLeaf) {
        IReadOnlyList<int> leftRows, rightRows;
        M5StaticUtilities.SplitRows(rows, parameters.Data, SplitAttribute, SplitValue, out leftRows, out rightRows);
        Left.BuildLeafModels(leftRows, parameters, cancellationToken);
        Right.BuildLeafModels(rightRows, parameters, cancellationToken);
        return;
      }
      int numP; // out value of BuildModel that is not needed here
      Model = M5StaticUtilities.BuildModel(rows, parameters, parameters.LeafModel, cancellationToken, out numP);
    }

    /// <summary>Enumerates this node and all of its descendants in breadth-first order (lazily).</summary>
    internal IEnumerable<M5NodeModel> EnumerateNodes() {
      var queue = new Queue<M5NodeModel>();
      queue.Enqueue(this);
      while (queue.Count != 0) {
        var cur = queue.Dequeue();
        yield return cur;
        if (cur.Left != null) queue.Enqueue(cur.Left);
        if (cur.Right != null) queue.Enqueue(cur.Right);
      }
    }

    #region Helpers
    // Predicts a single row by walking down to the responsible leaf.
    private double GetEstimatedValue(IDataset dataset, int row) {
      if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttribute, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
      return GetLeafModel().GetEstimatedValues(dataset, new[] {row}).First();
    }

    // Returns the leaf model, or fails with a clear message when none is assigned
    // (Split resets Model to null; BuildLeafModels assigns it).
    private IRegressionModel GetLeafModel() {
      if (Model == null) throw new NotSupportedException("The model has not been built correctly");
      return Model;
    }
    #endregion

    /// <summary>
    /// Tree node variant created when the configured leaf model provides confidence; additionally
    /// exposes variance estimates.
    /// </summary>
    [StorableClass]
    private sealed class ConfidenceM5NodeModel : M5NodeModel, IConfidenceRegressionModel {
      #region HLConstructors
      [StorableConstructor]
      private ConfidenceM5NodeModel(bool deserializing) : base(deserializing) { }
      private ConfidenceM5NodeModel(ConfidenceM5NodeModel original, Cloner cloner) : base(original, cloner) { }
      public ConfidenceM5NodeModel(string targetAttr) : base(targetAttr) { }
      public ConfidenceM5NodeModel(M5NodeModel parent) : base(parent) { }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new ConfidenceM5NodeModel(this, cloner);
      }
      #endregion

      /// <summary>
      /// Estimates prediction variances for <paramref name="rows"/>, routing rows down the tree like
      /// <see cref="GetEstimatedValues"/>.
      /// </summary>
      /// <exception cref="NotSupportedException">
      /// A leaf that has to predict has no model, or its model does not implement <see cref="IConfidenceRegressionModel"/>.
      /// </exception>
      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
        if (!IsLeaf) return rows.Select(row => GetEstimatedVariance(dataset, row));
        return GetConfidenceLeafModel().GetEstimatedVariances(dataset, rows);
      }

      // Variance for a single row, obtained from the responsible leaf.
      private double GetEstimatedVariance(IDataset dataset, int row) {
        if (!IsLeaf) {
          // Children are created with the same leaf-model settings, so they are confidence nodes as well.
          var child = (IConfidenceRegressionModel)(dataset.GetDoubleValue(SplitAttribute, row) <= SplitValue ? Left : Right);
          return child.GetEstimatedVariances(dataset, new[] {row}).Single();
        }
        return GetConfidenceLeafModel().GetEstimatedVariances(dataset, new[] {row}).First();
      }

      // The leaf model viewed as a confidence model; reports the same error as the value path when it is missing.
      private IConfidenceRegressionModel GetConfidenceLeafModel() {
        var confidenceModel = GetLeafModel() as IConfidenceRegressionModel;
        if (confidenceModel == null) throw new NotSupportedException("The leaf model does not provide variance estimates");
        return confidenceModel;
      }

      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
        return new ConfidenceRegressionSolution(this, problemData);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.