source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/RegressionNodeModel.cs @ 15830

Last change on this file since 15830 was 15830, checked in by bwerth, 19 months ago

#2847 adapted project to new rep structure; major changes to interfaces; restructures splitting and pruning

File size: 7.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
30namespace HeuristicLab.Algorithms.DataAnalysis {
31  [StorableClass]
32  public class RegressionNodeModel : RegressionModel {
33    #region Properties
34    public double PruningStrength = double.NaN;
35
36    [Storable]
37    private IReadOnlyList<string> Variables {
38      get {
39        if (IsLeaf && Model == null) return new List<string>();
40        if (IsLeaf) return Model.VariablesUsedForPrediction.ToList();
41        var set = new HashSet<string> {SplitAttribute};
42        var vl = Left.Variables;
43        var vr = Right.Variables;
44        for (var i = 0; i < vl.Count; i++) set.Add(vl[i]);
45        for (var i = 0; i < vr.Count; i++) set.Add(vr[i]);
46        return set.ToList();
47      }
48    }
49    [Storable]
50    internal int NumSamples { get; private set; }
51    [Storable]
52    internal bool IsLeaf { get; private set; }
53    [Storable]
54    internal IRegressionModel Model { get; private set; }
55
56    [Storable]
57    public string SplitAttribute { get; private set; }
58    [Storable]
59    public double SplitValue { get; private set; }
60    [Storable]
61    public RegressionNodeModel Left { get; private set; }
62    [Storable]
63    public RegressionNodeModel Right { get; private set; }
64    [Storable]
65    public RegressionNodeModel Parent { get; private set; }
66    #endregion
67
68    #region HLConstructors
69    [StorableConstructor]
70    protected RegressionNodeModel(bool deserializing) : base(deserializing) { }
71    protected RegressionNodeModel(RegressionNodeModel original, Cloner cloner) : base(original, cloner) {
72      IsLeaf = original.IsLeaf;
73      Model = cloner.Clone(original.Model);
74      SplitValue = original.SplitValue;
75      SplitAttribute = original.SplitAttribute;
76      Left = cloner.Clone(original.Left);
77      Right = cloner.Clone(original.Right);
78      Parent = cloner.Clone(original.Parent);
79      NumSamples = original.NumSamples;
80    }
81    private RegressionNodeModel(string targetAttr) : base(targetAttr) {
82      IsLeaf = true;
83    }
84    private RegressionNodeModel(RegressionNodeModel parent) : this(parent.TargetVariable) {
85      Parent = parent;
86      IsLeaf = true;
87    }
88    public override IDeepCloneable Clone(Cloner cloner) {
89      return new RegressionNodeModel(this, cloner);
90    }
91    public static RegressionNodeModel CreateNode(string targetAttr, RegressionTreeParameters regressionTreeParams) {
92      return regressionTreeParams.LeafModel.ProvidesConfidence ? new ConfidenceRegressionNodeModel(targetAttr) : new RegressionNodeModel(targetAttr);
93    }
94    private static RegressionNodeModel CreateNode(RegressionNodeModel parent, RegressionTreeParameters regressionTreeParams) {
95      return regressionTreeParams.LeafModel.ProvidesConfidence ? new ConfidenceRegressionNodeModel(parent) : new RegressionNodeModel(parent);
96    }
97    #endregion
98
99    #region RegressionModel
100    public override IEnumerable<string> VariablesUsedForPrediction {
101      get { return Variables; }
102    }
103    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
104      if (!IsLeaf) return rows.Select(row => GetEstimatedValue(dataset, row));
105      if (Model == null) throw new NotSupportedException("The model has not been built correctly");
106      return Model.GetEstimatedValues(dataset, rows);
107    }
108    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
109      return new RegressionSolution(this, problemData);
110    }
111    #endregion
112
113    internal void Split(RegressionTreeParameters regressionTreeParams, string splitAttribute, double splitValue, int numSamples) {
114      NumSamples = numSamples;
115      SplitAttribute = splitAttribute;
116      SplitValue = splitValue;
117      Left = CreateNode(this, regressionTreeParams);
118      Right = CreateNode(this, regressionTreeParams);
119      IsLeaf = false;
120    }
121
122    internal void ToLeaf() {
123      IsLeaf = true;
124      Right = null;
125      Left = null;
126    }
127
128    internal void SetLeafModel(IRegressionModel model) {
129      Model = model;
130    }
131
132    internal IEnumerable<RegressionNodeModel> EnumerateNodes() {
133      var queue = new Queue<RegressionNodeModel>();
134      queue.Enqueue(this);
135      while (queue.Count != 0) {
136        var cur = queue.Dequeue();
137        yield return cur;
138        if (cur.Left == null && cur.Right == null) continue;
139        if (cur.Left != null) queue.Enqueue(cur.Left);
140        if (cur.Right != null) queue.Enqueue(cur.Right);
141      }
142    }
143
144    #region Helpers
145    private double GetEstimatedValue(IDataset dataset, int row) {
146      if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttribute, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
147      if (Model == null) throw new NotSupportedException("The model has not been built correctly");
148      return Model.GetEstimatedValues(dataset, new[] {row}).First();
149    }
150    #endregion
151
152    [StorableClass]
153    private sealed class ConfidenceRegressionNodeModel : RegressionNodeModel, IConfidenceRegressionModel {
154      #region HLConstructors
155      [StorableConstructor]
156      private ConfidenceRegressionNodeModel(bool deserializing) : base(deserializing) { }
157      private ConfidenceRegressionNodeModel(ConfidenceRegressionNodeModel original, Cloner cloner) : base(original, cloner) { }
158      public ConfidenceRegressionNodeModel(string targetAttr) : base(targetAttr) { }
159      public ConfidenceRegressionNodeModel(RegressionNodeModel parent) : base(parent) { }
160      public override IDeepCloneable Clone(Cloner cloner) {
161        return new ConfidenceRegressionNodeModel(this, cloner);
162      }
163      #endregion
164
165      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
166        return IsLeaf ? ((IConfidenceRegressionModel)Model).GetEstimatedVariances(dataset, rows) : rows.Select(row => GetEstimatedVariance(dataset, row));
167      }
168
169      private double GetEstimatedVariance(IDataset dataset, int row) {
170        if (!IsLeaf)
171          return ((IConfidenceRegressionModel)(dataset.GetDoubleValue(SplitAttribute, row) <= SplitValue ? Left : Right)).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
172        return ((IConfidenceRegressionModel)Model).GetEstimatedVariances(dataset, new[] {row}).First();
173      }
174
175      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
176        return new ConfidenceRegressionSolution(this, problemData);
177      }
178    }
179  }
180}
Note: See TracBrowser for help on using the repository browser.