#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
|
---|
21 |
|
---|
22 | using System;
|
---|
23 | using System.Collections.Generic;
|
---|
24 | using System.Linq;
|
---|
25 | using System.Threading;
|
---|
26 | using HeuristicLab.Common;
|
---|
27 | using HeuristicLab.Core;
|
---|
28 | using HeuristicLab.Data;
|
---|
29 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
30 | using HeuristicLab.Problems.DataAnalysis;
|
---|
31 |
|
---|
32 | namespace HeuristicLab.Algorithms.DataAnalysis {
|
---|
33 | [StorableClass]
|
---|
34 | internal class M5NodeModel : RegressionModel {
|
---|
35 | #region Properties
|
---|
/// <summary>True when this node has no children and predicts through <see cref="NodeModel"/>.</summary>
[Storable]
internal bool IsLeaf { get; private set; }

/// <summary>Regression model built for this node by BuildModel; null until a model has been built.</summary>
[Storable]
internal IRegressionModel NodeModel { get; private set; }

/// <summary>Name of the variable this node splits on; reset to null at the start of Split.</summary>
[Storable]
internal string SplitAttr { get; private set; }

/// <summary>
/// Split threshold: rows whose <see cref="SplitAttr"/> value is &lt;= this value go to <see cref="Left"/>,
/// all others to <see cref="Right"/>. Reset to NaN at the start of Split.
/// </summary>
[Storable]
internal double SplitValue { get; private set; }

/// <summary>Child receiving the rows with values &lt;= <see cref="SplitValue"/>; null for leaves.</summary>
[Storable]
internal M5NodeModel Left { get; private set; }

/// <summary>Child receiving the rows with values &gt; <see cref="SplitValue"/> (or NaN); null for leaves.</summary>
[Storable]
internal M5NodeModel Right { get; private set; }

/// <summary>Node this node was created under; null for nodes created from a target variable name or after ToRuleNode.</summary>
[Storable]
internal M5NodeModel Parent { get; set; }

/// <summary>Number of rows handed to the most recent Split call on this node.</summary>
[Storable]
internal int NumSamples { get; private set; }

/// <summary>
/// Parameter count assigned during Prune: <see cref="NodeModelParams"/> for leaves,
/// Left.NumParam + Right.NumParam + 1 for inner nodes.
/// </summary>
[Storable]
internal int NumParam { get; set; }

/// <summary>Parameter count reported by the leaf type when <see cref="NodeModel"/> was built.</summary>
[Storable]
internal int NodeModelParams { get; set; }

/// <summary>Allowed input variables captured when Split was called; backs VariablesUsedForPrediction.</summary>
[Storable]
private IReadOnlyList<string> Variables { get; set; }
|
---|
58 | #endregion
|
---|
59 |
|
---|
60 | #region HLConstructors
|
---|
/// <summary>Constructor used by the persistence framework; forwards the flag to the base class.</summary>
[StorableConstructor]
protected M5NodeModel(bool deserializing)
  : base(deserializing) {
}
|
---|
/// <summary>
/// Cloning constructor. Copies all scalar state and clones the node model, both children
/// and the parent through <paramref name="cloner"/>.
/// </summary>
/// <param name="original">Node to copy from.</param>
/// <param name="cloner">Cloner used for the referenced models and nodes.</param>
protected M5NodeModel(M5NodeModel original, Cloner cloner) : base(original, cloner) {
  IsLeaf = original.IsLeaf;
  NodeModel = cloner.Clone(original.NodeModel);
  SplitValue = original.SplitValue;
  SplitAttr = original.SplitAttr;
  Left = cloner.Clone(original.Left);
  Right = cloner.Clone(original.Right);
  Parent = cloner.Clone(original.Parent);
  NumParam = original.NumParam;
  // NodeModelParams is storable state as well and Prune reads it to set NumParam,
  // so a clone must carry the same value as the original.
  NodeModelParams = original.NodeModelParams;
  NumSamples = original.NumSamples;
  // copy the list so the clone does not share the variable collection with the original
  Variables = original.Variables != null ? original.Variables.ToList() : null;
}
|
---|
/// <summary>Creates a node without a parent for the given target variable.</summary>
/// <param name="targetAttr">Name of the target variable, passed to the base class.</param>
protected M5NodeModel(string targetAttr)
  : base(targetAttr) {
}
|
---|
/// <summary>Creates a child node that shares the target variable of <paramref name="parent"/>.</summary>
/// <param name="parent">Node the new node is attached to; must not be null.</param>
protected M5NodeModel(M5NodeModel parent)
  : base(parent.TargetVariable) {
  this.Parent = parent;
}
|
---|
/// <summary>Creates a deep clone of this node through the cloning constructor.</summary>
/// <param name="cloner">Cloner passed on to the cloning constructor.</param>
/// <returns>The newly created <see cref="M5NodeModel"/>.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new M5NodeModel(this, cloner);
  return copy;
}
|
---|
/// <summary>
/// Creates a parentless node for <paramref name="targetAttr"/>. A <see cref="ConfidenceM5NodeModel"/>
/// is returned when the configured leaf type is an ILeafType&lt;IConfidenceRegressionModel&gt;,
/// a plain <see cref="M5NodeModel"/> otherwise.
/// </summary>
public static M5NodeModel CreateNode(string targetAttr, M5CreationParameters m5CreationParams) {
  if (m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel>) {
    return new ConfidenceM5NodeModel(targetAttr);
  }
  return new M5NodeModel(targetAttr);
}
|
---|
/// <summary>
/// Creates a child node below <paramref name="parent"/>. A <see cref="ConfidenceM5NodeModel"/>
/// is returned when the configured leaf type is an ILeafType&lt;IConfidenceRegressionModel&gt;,
/// a plain <see cref="M5NodeModel"/> otherwise.
/// </summary>
private static M5NodeModel CreateNode(M5NodeModel parent, M5CreationParameters m5CreationParams) {
  if (m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel>) {
    return new ConfidenceM5NodeModel(parent);
  }
  return new M5NodeModel(parent);
}
|
---|
88 | #endregion
|
---|
89 |
|
---|
90 | #region RegressionModel
|
---|
/// <summary>Input variables captured by the last Split call (null before Split has run).</summary>
public override IEnumerable<string> VariablesUsedForPrediction {
  get {
    return Variables;
  }
}
|
---|
/// <summary>
/// Returns the estimates for <paramref name="rows"/>. A leaf delegates the whole batch to its
/// node model; an inner node lazily routes every row down the tree.
/// </summary>
/// <exception cref="NotSupportedException">This node is a leaf without a node model.</exception>
public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
  if (IsLeaf) {
    if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
    return NodeModel.GetEstimatedValues(dataset, rows);
  }
  // deferred: each row is routed to its leaf only when the result is enumerated
  return rows.Select(row => GetEstimatedValue(dataset, row));
}
|
---|
/// <summary>Wraps this model and <paramref name="problemData"/> in a new <see cref="RegressionSolution"/>.</summary>
public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  var solution = new RegressionSolution(this, problemData);
  return solution;
}
|
---|
102 | #endregion
|
---|
103 |
|
---|
/// <summary>
/// (Re)builds the subtree below this node from <paramref name="rows"/>. All split state is reset,
/// the configured splitter is asked for a split, and if one is found and both partitions contain
/// at least MinLeafSize rows, two children are created and split recursively. Otherwise the node
/// becomes a leaf.
/// </summary>
/// <param name="rows">Row indices (into m5CreationParams.Data) that reach this node.</param>
/// <param name="m5CreationParams">Creation parameters providing data, splitter and minimum leaf size.</param>
/// <param name="globalStdDev">Only forwarded to the children; not evaluated in this method.</param>
internal void Split(IReadOnlyList<int> rows, M5CreationParameters m5CreationParams, double globalStdDev) {
  // reset everything a previous Split/Prune may have left behind
  Variables = m5CreationParams.AllowedInputVariables.ToArray();
  NumSamples = rows.Count;
  Right = null;
  Left = null;
  NodeModel = null;
  SplitAttr = null;
  SplitValue = double.NaN;

  // the splitter only sees the rows of this node
  var nodeData = ReduceDataset(m5CreationParams.Data, rows);
  var nodeProblem = new RegressionProblemData(nodeData, Variables, TargetVariable);
  string candidateAttr;
  double candidateValue;
  var splitFound = m5CreationParams.Split.Split(nodeProblem, m5CreationParams.MinLeafSize, out candidateAttr, out candidateValue);
  IsLeaf = !splitFound;
  if (IsLeaf) return;

  // partition the rows according to the proposed split
  IReadOnlyList<int> lowerRows, upperRows;
  SplitRows(rows, m5CreationParams.Data, candidateAttr, candidateValue, out lowerRows, out upperRows);

  var bothLargeEnough = lowerRows.Count >= m5CreationParams.MinLeafSize && upperRows.Count >= m5CreationParams.MinLeafSize;
  if (!bothLargeEnough) {
    // reject the split; SplitAttr/SplitValue keep their reset values
    IsLeaf = true;
    return;
  }
  SplitAttr = candidateAttr;
  SplitValue = candidateValue;

  // grow both subtrees
  Left = CreateNode(this, m5CreationParams);
  Left.Split(lowerRows, m5CreationParams, globalStdDev);
  Right = CreateNode(this, m5CreationParams);
  Right.Split(upperRows, m5CreationParams, globalStdDev);
}
|
---|
136 |
|
---|
/// <summary>
/// Prunes the subtree below this node bottom-up. Leaves get a model built with the pruning leaf
/// type. An inner node is only examined when at least one child reports true; it is then given
/// a model of its own and, if the configured pruning strategy agrees, collapsed into a leaf.
/// </summary>
/// <param name="trainingRows">Rows used to build node models.</param>
/// <param name="testRows">Rows handed to the pruning strategy.</param>
/// <param name="m5CreationParams">Creation parameters (data, random, pruning leaf type, pruning strategy, results).</param>
/// <param name="cancellation">Token forwarded to model building.</param>
/// <param name="globalStdDev">Forwarded to the children and to the pruning strategy.</param>
/// <returns>True if this node is a leaf when the call returns, false if it remains an inner node.</returns>
internal bool Prune(IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows, M5CreationParameters m5CreationParams, CancellationToken cancellation, double globalStdDev) {
  if (IsLeaf) {
    BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
    NumParam = NodeModelParams;
    return true;
  }

  // split holdout and training rows along this node's split
  IReadOnlyList<int> testLower, testUpper;
  SplitRows(testRows, m5CreationParams.Data, SplitAttr, SplitValue, out testLower, out testUpper);
  IReadOnlyList<int> trainLower, trainUpper;
  SplitRows(trainingRows, m5CreationParams.Data, SplitAttr, SplitValue, out trainLower, out trainUpper);

  // prune the children first
  var leftIsLeaf = Left.Prune(trainLower, testLower, m5CreationParams, cancellation, globalStdDev);
  var rightIsLeaf = Right.Prune(trainUpper, testUpper, m5CreationParams, cancellation, globalStdDev);
  NumParam = Left.NumParam + Right.NumParam + 1;

  //TODO check if this reduces quality. It reduces training effort (considerably for some pruning types)
  if (!(leftIsLeaf || rightIsLeaf)) return false;

  BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);

  // ask the pruning strategy whether the children should be removed
  var pruning = (PruningBase) m5CreationParams.Pruningtype;
  var collapse = pruning.Prune(this, m5CreationParams, testRows, globalStdDev);
  if (!collapse) return false;

  // convert to a leaf: the leaves of this subtree are replaced by this single node,
  // so the leaf counter result drops by (leaves in subtree - 1)
  var leafCounter = (IntValue) m5CreationParams.Results[M5RuleModel.NoCurrentLeafesResultName].Value;
  leafCounter.Value -= EnumerateNodes().Count(x => x.IsLeaf) - 1;
  IsLeaf = true;
  Right = null;
  Left = null;
  NumParam = NodeModelParams;
  return true;
}
|
---|
170 |
|
---|
/// <summary>
/// Builds the node models of all leaves below (and including) this node. The rows are routed
/// down the tree along the stored splits; each leaf builds its model from the rows it receives.
/// </summary>
/// <param name="rows">Row indices into <paramref name="data"/> that reach this node.</param>
/// <param name="random">Random source forwarded to the leaf type.</param>
/// <param name="data">Dataset the row indices refer to.</param>
/// <param name="leafType">Leaf type used to build the leaf models.</param>
/// <param name="cancellation">Token forwarded to model building.</param>
internal void InstallModels(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
  if (IsLeaf) {
    BuildModel(rows, data, random, leafType, cancellation);
    return;
  }

  // inner node: pass each partition on to the matching child
  IReadOnlyList<int> lowerRows, upperRows;
  SplitRows(rows, data, SplitAttr, SplitValue, out lowerRows, out upperRows);
  Left.InstallModels(lowerRows, random, data, leafType, cancellation);
  Right.InstallModels(upperRows, random, data, leafType, cancellation);
}
|
---|
181 |
|
---|
/// <summary>
/// Lazily enumerates this node and all nodes below it in breadth-first order
/// (left child before right child).
/// </summary>
internal IEnumerable<M5NodeModel> EnumerateNodes() {
  var pending = new Queue<M5NodeModel>();
  pending.Enqueue(this);
  while (pending.Count > 0) {
    var node = pending.Dequeue();
    yield return node;
    // children are read after the consumer has seen the node, exactly when iteration resumes
    var leftChild = node.Left;
    if (leftChild != null) pending.Enqueue(leftChild);
    var rightChild = node.Right;
    if (rightChild != null) pending.Enqueue(rightChild);
  }
}
|
---|
193 |
|
---|
/// <summary>
/// Detaches this node from the node above it by clearing <see cref="Parent"/>.
/// Children and models are left untouched.
/// </summary>
internal void ToRuleNode() {
  this.Parent = null;
}
|
---|
197 |
|
---|
198 | #region Helpers
|
---|
/// <summary>
/// Estimates a single row: walks from this node down to the responsible leaf
/// (values &lt;= SplitValue go left, everything else right) and evaluates that leaf's model.
/// </summary>
/// <exception cref="NotSupportedException">The reached leaf has no node model.</exception>
private double GetEstimatedValue(IDataset dataset, int row) {
  var node = this;
  while (!node.IsLeaf) {
    var goLeft = dataset.GetDoubleValue(node.SplitAttr, row) <= node.SplitValue;
    node = goLeft ? node.Left : node.Right;
  }
  if (node.NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
  return node.NodeModel.GetEstimatedValues(dataset, new[] {row}).First();
}
|
---|
204 |
|
---|
/// <summary>
/// Builds <see cref="NodeModel"/> from the given rows using <paramref name="leafType"/> and stores
/// the parameter count it reports in <see cref="NodeModelParams"/>. The reduced dataset is used
/// entirely for training; the test partition is set to an empty range at its end.
/// </summary>
/// <exception cref="OperationCanceledException">Cancellation was requested (checked after the model has been built).</exception>
private void BuildModel(IReadOnlyList<int> rows, IDataset data, IRandom random, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
  var subset = ReduceDataset(data, rows);
  var problemData = new RegressionProblemData(subset, VariablesUsedForPrediction, TargetVariable);
  problemData.TrainingPartition.Start = 0;

  // training = [0, rowCount), test = [rowCount, rowCount); same assignment order as a chained assignment
  var rowCount = subset.Rows;
  problemData.TestPartition.End = rowCount;
  problemData.TestPartition.Start = rowCount;
  problemData.TrainingPartition.End = rowCount;

  int parameterCount;
  NodeModel = leafType.BuildModel(problemData, random, cancellation, out parameterCount);
  NodeModelParams = parameterCount;
  cancellation.ThrowIfCancellationRequested();
}
|
---|
216 |
|
---|
/// <summary>
/// Creates a new dataset that contains only the prediction variables plus the target variable
/// (target last) and only the given rows, in the order of <paramref name="rows"/>.
/// </summary>
private IDataset ReduceDataset(IDataset data, IReadOnlyList<int> rows) {
  var columnNames = VariablesUsedForPrediction.Concat(new[] {TargetVariable}).ToArray();
  var columnValues = columnNames.Select(name => data.GetDoubleValues(name, rows).ToList());
  return new Dataset(columnNames, columnValues);
}
|
---|
220 |
|
---|
/// <summary>
/// Partitions <paramref name="rows"/> by the value of <paramref name="splitAttr"/>:
/// rows with a value &lt;= <paramref name="splitValue"/> go to <paramref name="leftRows"/>,
/// all other rows (including NaN values, for which the comparison is false) to <paramref name="rightRows"/>.
/// The relative order of the rows is kept in both partitions.
/// </summary>
private static void SplitRows(IReadOnlyList<int> rows, IDataset data, string splitAttr, double splitValue, out IReadOnlyList<int> leftRows, out IReadOnlyList<int> rightRows) {
  var lower = new List<int>();
  var upper = new List<int>();
  var values = data.GetDoubleValues(splitAttr, rows);
  foreach (var pair in rows.Zip(values, (row, value) => new {row, value})) {
    if (pair.value <= splitValue) lower.Add(pair.row);
    else upper.Add(pair.row);
  }
  leftRows = lower;
  rightRows = upper;
}
|
---|
226 | #endregion
|
---|
227 |
|
---|
/// <summary>
/// Tree node whose leaf models are <see cref="IConfidenceRegressionModel"/>s, so that variances
/// can be estimated in addition to values.
/// </summary>
[StorableClass]
private sealed class ConfidenceM5NodeModel : M5NodeModel, IConfidenceRegressionModel {
  #region HLConstructors
  [StorableConstructor]
  private ConfidenceM5NodeModel(bool deserializing) : base(deserializing) { }
  private ConfidenceM5NodeModel(ConfidenceM5NodeModel original, Cloner cloner) : base(original, cloner) { }
  public ConfidenceM5NodeModel(string targetAttr) : base(targetAttr) { }
  public ConfidenceM5NodeModel(M5NodeModel parent) : base(parent) { }
  public override IDeepCloneable Clone(Cloner cloner) {
    return new ConfidenceM5NodeModel(this, cloner);
  }
  #endregion

  /// <summary>
  /// Returns the estimated variances for <paramref name="rows"/>. A leaf delegates the whole batch
  /// to its node model; an inner node lazily routes every row to the responsible child.
  /// </summary>
  /// <exception cref="NotSupportedException">This node is a leaf without a node model.</exception>
  public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    if (!IsLeaf) return rows.Select(row => GetEstimatedVariance(dataset, row));
    return GetConfidenceNodeModel().GetEstimatedVariances(dataset, rows);
  }

  /// <summary>Estimates the variance for a single row, descending into the child selected by the split.</summary>
  private double GetEstimatedVariance(IDataset dataset, int row) {
    if (!IsLeaf) {
      // children of a confidence node are expected to be confidence models as well; the cast fails otherwise
      var child = (IConfidenceRegressionModel) (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right);
      return child.GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
    }
    return GetConfidenceNodeModel().GetEstimatedVariances(dataset, new[] {row}).First();
  }

  /// <summary>
  /// Returns the node model as a confidence model. Throws the same exception as the value
  /// estimation of the base class when no model has been built, instead of failing with a
  /// NullReferenceException.
  /// </summary>
  private IConfidenceRegressionModel GetConfidenceNodeModel() {
    if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
    return (IConfidenceRegressionModel) NodeModel;
  }

  /// <summary>Wraps this model and <paramref name="problemData"/> in a new <see cref="ConfidenceRegressionSolution"/>.</summary>
  public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    return new ConfidenceRegressionSolution(this, problemData);
  }
}
|
---|
255 | }
|
---|
256 | } |
---|