Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/Utilities/RegressionTreeAnalyzer.cs @ 17120

Last change on this file since 17120 was 16855, checked in by gkronber, 5 years ago

#2847: moved M5 regression into a separate plugin as it depends on HL.DataAnalysis.Algorithms.Glmnet plugin

File size: 12.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Optimization;
30using HeuristicLab.Problems.DataAnalysis;
31using HEAL.Attic;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Static helpers that turn regression tree / regression rule set models into result objects
  /// (data tables, scatter plots, scopes and symbolic expression trees) for display.
  /// </summary>
  public static class RegressionTreeAnalyzer {
    private const string ConditionResultName = "Condition";
    private const string CoverResultName = "Covered Instances";
    private const string CoverageDiagramResultName = "Coverage";
    private const string RuleModelResultName = "RuleModel";
    // AnalyzeNodes does not create per-leaf regression solutions when the tree has more leaves than this.
    private const int MaxDisplayedNodeModels = 200;
    // Number of equally sized steps used for the filler points of the pruning chart.
    private const int PruningChartFillerSteps = 200;

    /// <summary>
    /// Counts how often each variable is used as a split attribute over all rules of the rule set.
    /// </summary>
    /// <param name="ruleSetModel">The rule set whose rules' split attributes are counted.</param>
    /// <returns>
    /// One entry per variable in <c>VariablesUsedForPrediction</c> (starting at 0); a split attribute
    /// missing from that list is added with its count instead of raising an exception.
    /// </returns>
    public static Dictionary<string, int> GetRuleVariableFrequences(RegressionRuleSetModel ruleSetModel) {
      var res = ruleSetModel.VariablesUsedForPrediction.ToDictionary(x => x, x => 0);
      foreach (var rule in ruleSetModel.Rules)
        foreach (var att in rule.SplitAttributes)
          IncrementCount(res, att);
      return res;
    }

    /// <summary>
    /// Counts how often each variable is used as the split attribute of an inner (non-leaf) node of the tree.
    /// </summary>
    /// <param name="treeModel">The tree whose inner nodes are inspected.</param>
    /// <returns>
    /// One entry per variable in <c>VariablesUsedForPrediction</c> (starting at 0); a split attribute
    /// missing from that list is added with its count instead of raising an exception.
    /// </returns>
    public static Dictionary<string, int> GetTreeVariableFrequences(RegressionNodeTreeModel treeModel) {
      var res = treeModel.VariablesUsedForPrediction.ToDictionary(x => x, x => 0);
      foreach (var node in treeModel.Root.EnumerateNodes().Where(x => !x.IsLeaf))
        IncrementCount(res, node.SplitAttribute);
      return res;
    }

    /// <summary>
    /// Creates a histogram ("LeafDepths") of the depths of all leaves of the tree; the root has depth 0.
    /// </summary>
    public static Result CreateLeafDepthHistogram(RegressionNodeTreeModel treeModel) {
      var list = new List<int>();
      GetLeafDepths(treeModel.Root, 0, list);
      var row = new DataRow("Depths", "", list.Select(x => (double)x)) {
        VisualProperties = { ChartType = DataRowVisualProperties.DataRowChartType.Histogram }
      };
      var hist = new DataTable("LeafDepths");
      hist.Rows.Add(row);
      return new Result(hist.Name, hist);
    }

    /// <summary>
    /// Creates one result scope per rule ("Rule0", "Rule1", ...). Each rule is evaluated only on the rows
    /// not covered by the rules before it.
    /// </summary>
    /// <param name="ruleSetModel">The rule set to describe.</param>
    /// <param name="pd">Problem data the first rule is evaluated on.</param>
    /// <param name="resultName">Name of the returned result.</param>
    /// <param name="displayModels">Whether a regression solution is created for each rule on its covered rows.</param>
    public static Result CreateRulesResult(RegressionRuleSetModel ruleSetModel, IRegressionProblemData pd, string resultName, bool displayModels) {
      var res = new ResultCollection();
      var i = 0;
      // Rows left over after the preceding rules; becomes null once every row is covered.
      var remaining = pd;
      foreach (var rule in ruleSetModel.Rules) {
        IRegressionProblemData notCovered;
        res.Add(new Result("Rule" + i++, CreateRulesResult(rule, remaining, displayModels, out notCovered)));
        remaining = notCovered;
      }
      return new Result(resultName, res);
    }

    /// <summary>
    /// Creates a column chart with, per rule, the number of training and test rows whose first
    /// covering rule (in rule order) is that rule.
    /// </summary>
    public static IResult CreateCoverageDiagram(RegressionRuleSetModel setModel, IRegressionProblemData problemData) {
      var res = new DataTable(CoverageDiagramResultName);
      var training = CountCoverage(setModel, problemData.Dataset, problemData.TrainingIndices);
      var test = CountCoverage(setModel, problemData.Dataset, problemData.TestIndices);
      res.Rows.Add(new DataRow("Training", "", training));
      res.Rows.Add(new DataRow("Test", "", test));

      foreach (var row in res.Rows)
        row.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Columns;
      res.VisualProperties.XAxisMaximumFixedValue = training.Count + 1;
      res.VisualProperties.XAxisMaximumAuto = false;
      res.VisualProperties.XAxisMinimumFixedValue = 0;
      res.VisualProperties.XAxisMinimumAuto = false;
      res.VisualProperties.XAxisTitle = "Rule";
      res.VisualProperties.YAxisTitle = "Covered Instances";

      return new Result(CoverageDiagramResultName, res);
    }

    // Recursively appends the depth of every leaf below n to res.
    // A node with neither child is treated as a leaf; null children are skipped.
    private static void GetLeafDepths(RegressionNodeModel n, int depth, ICollection<int> res) {
      if (n == null) return;
      if (n.Left == null && n.Right == null) res.Add(depth);
      else {
        GetLeafDepths(n.Left, depth + 1, res);
        GetLeafDepths(n.Right, depth + 1, res);
      }
    }

    // Describes a single rule as a scope holding its condition, its number of covered training rows and
    // (optionally) a regression solution trained on the covered rows.
    // notCovered receives the rows of pd the rule does not cover, or null if there are none.
    // pd may be null when earlier rules already covered every row; the rule then covers 0 instances
    // and no model can be built for it.
    private static IScope CreateRulesResult(RegressionRuleModel regressionRuleModel, IRegressionProblemData pd, bool displayModels, out IRegressionProblemData notCovered) {
      var res = new Scope("RuleModels");
      res.Variables.Add(new Variable(ConditionResultName, new StringValue(regressionRuleModel.ToCompactString())));

      if (pd == null) {
        notCovered = null;
        res.Variables.Add(new Variable(CoverResultName, new IntValue(0)));
        return res;
      }

      // Evaluate Covers once per row and split into covered / remaining rows (original order kept).
      int[] coveredTraining, remainingTraining, coveredTest, remainingTest;
      PartitionRows(regressionRuleModel, pd.Dataset, pd.TrainingIndices, out coveredTraining, out remainingTraining);
      PartitionRows(regressionRuleModel, pd.Dataset, pd.TestIndices, out coveredTest, out remainingTest);

      notCovered = remainingTraining.Length > 0 || remainingTest.Length > 0
        ? CreatePartitionedProblemData(pd, remainingTraining, remainingTest)
        : null;

      res.Variables.Add(new Variable(CoverResultName, new IntValue(coveredTraining.Length)));
      if (displayModels)
        res.Variables.Add(new Variable(RuleModelResultName, regressionRuleModel.CreateRegressionSolution(CreatePartitionedProblemData(pd, coveredTraining, coveredTest))));
      return res;
    }

    // Splits rows into those covered by the rule and those that are not, preserving their order.
    private static void PartitionRows(RegressionRuleModel rule, IDataset data, IEnumerable<int> rows, out int[] covered, out int[] notCovered) {
      var coveredList = new List<int>();
      var notCoveredList = new List<int>();
      foreach (var row in rows) {
        if (rule.Covers(data, row)) coveredList.Add(row);
        else notCoveredList.Add(row);
      }
      covered = coveredList.ToArray();
      notCovered = notCoveredList.ToArray();
    }

    // Builds problem data from the double-valued variables of the given training rows followed by the given
    // test rows. The training partition ends and the test partition starts at training.Length; the test
    // partition ends after the last row.
    private static RegressionProblemData CreatePartitionedProblemData(IRegressionProblemData pd, int[] training, int[] test) {
      var rows = training.Concat(test);
      var data = new Dataset(pd.Dataset.DoubleVariables, pd.Dataset.DoubleVariables.Select(v => pd.Dataset.GetDoubleValues(v, rows).ToArray()));
      var res = new RegressionProblemData(data, pd.AllowedInputVariables, pd.TargetVariable);
      res.TestPartition.Start = res.TrainingPartition.End = training.Length;
      res.TestPartition.End = training.Length + test.Length;
      return res;
    }

    // Counts, per rule index, the rows whose FIRST covering rule is that rule.
    // Rows covered by no rule are not counted.
    private static IReadOnlyList<double> CountCoverage(RegressionRuleSetModel setModel, IDataset data, IEnumerable<int> rows) {
      var rules = setModel.Rules.ToArray();
      var res = new double[rules.Length];
      foreach (var row in rows)
        for (var i = 0; i < rules.Length; i++)
          if (rules[i].Covers(data, row)) {
            res[i]++;
            break;
          }
      return res;
    }

    /// <summary>
    /// Adds a text-based symbolic rendering of the tree ("DecisionTree") to the results.
    /// Unless the tree has more than <see cref="MaxDisplayedNodeModels"/> leaves, it also adds a scope
    /// ("NodeModels") with one regression solution per leaf. Each solution is built on the training/test
    /// rows routed to that leaf.
    /// </summary>
    public static void AnalyzeNodes(RegressionNodeTreeModel tree, ResultCollection results, IRegressionProblemData pd) {
      var dict = new Dictionary<int, RegressionNodeModel>();
      var trainingLeafRows = new Dictionary<int, IReadOnlyList<int>>();
      var testLeafRows = new Dictionary<int, IReadOnlyList<int>>();
      var modelNumber = new IntValue(1);
      var symtree = new SymbolicExpressionTree(MirrorTree(tree.Root, dict, trainingLeafRows, testLeafRows, modelNumber, pd.Dataset, pd.TrainingIndices.ToList(), pd.TestIndices.ToList()));
      results.AddOrUpdateResult("DecisionTree", symtree);

      if (dict.Count > MaxDisplayedNodeModels) return;
      var models = new Scope("NodeModels");
      results.AddOrUpdateResult("NodeModels", models);
      foreach (var m in dict.Keys.OrderBy(x => x))
        models.Variables.Add(new Variable("Model " + m, dict[m].CreateRegressionSolution(Subselect(pd, trainingLeafRows[m], testLeafRows[m]))));
    }

    /// <summary>
    /// Adds a scatter plot ("PruningSizes") of tree size versus pruning strength.
    /// The size starts at the leaf count of the unpruned tree at the current pruning strength. It is then
    /// reduced, in ascending strength order, by the number of inner nodes carrying each strength.
    /// Evenly spaced filler points are drawn between the current and the largest strength.
    /// </summary>
    public static void PruningChart(RegressionNodeTreeModel tree, ComplexityPruning pruning, ResultCollection results) {
      // Breadth-first walk: count leaves and, per pruning strength, the inner nodes carrying it.
      var nodes = new Queue<RegressionNodeModel>();
      nodes.Enqueue(tree.Root);
      var treeSize = 0.0;
      var strengths = new SortedList<double, int>();
      while (nodes.Count > 0) {
        var n = nodes.Dequeue();

        if (n.IsLeaf) {
          treeSize++;
          continue;
        }

        int count;
        strengths.TryGetValue(n.PruningStrength, out count);
        strengths[n.PruningStrength] = count + 1;
        nodes.Enqueue(n.Left);
        nodes.Enqueue(n.Right);
      }
      if (strengths.Count == 0) return;

      var plot = new ScatterPlot("Pruned Sizes", "") {
        VisualProperties = {
          XAxisTitle = "Pruning Strength",
          YAxisTitle = "Tree Size",
          XAxisMinimumAuto = false,
          XAxisMinimumFixedValue = 0
        }
      };
      var row = new ScatterPlotDataRow("TreeSizes", "", new List<Point2D<double>>());
      row.Points.Add(new Point2D<double>(pruning.PruningStrength, treeSize));

      var fillerDots = new Queue<double>();
      var minX = pruning.PruningStrength;
      var maxX = strengths.Keys[strengths.Count - 1];
      var step = (maxX - minX) / PruningChartFillerSteps;
      // step is 0 when maxX == minX, and NaN/infinite for non-finite strengths. Accumulating x += step
      // would never terminate for 0 or infinite steps, so filler points are generated by index and only
      // for a finite positive step.
      if (step > 0 && !double.IsInfinity(step)) {
        for (var i = 0; i <= PruningChartFillerSteps; i++) {
          var x = minX + i * step;
          if (x > maxX) break;
          fillerDots.Enqueue(x);
        }
      }

      foreach (var entry in strengths) {
        while (fillerDots.Count > 0 && entry.Key > fillerDots.Peek())
          row.Points.Add(new Point2D<double>(fillerDots.Dequeue(), treeSize));
        treeSize -= entry.Value;
        row.Points.Add(new Point2D<double>(entry.Key, treeSize));
      }

      row.VisualProperties.PointSize = 6;
      plot.Rows.Add(row);
      results.AddOrUpdateResult("PruningSizes", plot);
    }

    // Builds problem data restricted to the given training rows followed by the given test rows.
    // Partitions are set to [0, training.Count) and [training.Count, training.Count + test.Count).
    private static IRegressionProblemData Subselect(IRegressionProblemData data, IReadOnlyList<int> training, IReadOnlyList<int> test) {
      var dataset = RegressionTreeUtilities.ReduceDataset(data.Dataset, training.Concat(test).ToList(), data.AllowedInputVariables.ToList(), data.TargetVariable);
      var res = new RegressionProblemData(dataset, data.AllowedInputVariables, data.TargetVariable);
      res.TrainingPartition.Start = 0;
      res.TrainingPartition.End = training.Count;
      res.TestPartition.Start = training.Count;
      res.TestPartition.End = training.Count + test.Count;
      return res;
    }

    // Recursively converts the regression tree into a tree of text nodes for display.
    // Inner nodes show "attribute <= value" (plus "pf = strength" when the strength is not NaN).
    // Leaves are labelled "Model <id>\n(<training rows>/<test rows>)". Each leaf takes the next id from
    // nextId, and its node, training rows and test rows are recorded under that id. Rows are routed to the
    // children via RegressionTreeUtilities.SplitRows.
    private static SymbolicExpressionTreeNode MirrorTree(RegressionNodeModel regressionNode, IDictionary<int, RegressionNodeModel> dict,
      IDictionary<int, IReadOnlyList<int>> trainingLeafRows,
      IDictionary<int, IReadOnlyList<int>> testLeafRows,
      IntValue nextId, IDataset data, IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows) {
      if (regressionNode.IsLeaf) {
        var i = nextId.Value++;
        dict.Add(i, regressionNode);
        trainingLeafRows.Add(i, trainingRows);
        testLeafRows.Add(i, testRows);
        return new SymbolicExpressionTreeNode(new TextSymbol("Model " + i + "\n(" + trainingRows.Count + "/" + testRows.Count + ")"));
      }

      var text = regressionNode.SplitAttribute + " <= " + regressionNode.SplitValue.ToString("0.###");
      if (!double.IsNaN(regressionNode.PruningStrength))
        text += "\npf = " + regressionNode.PruningStrength.ToString("0.###");

      var textNode = new SymbolicExpressionTreeNode(new TextSymbol(text));
      IReadOnlyList<int> lTrainingRows, rTrainingRows;
      IReadOnlyList<int> lTestRows, rTestRows;
      RegressionTreeUtilities.SplitRows(trainingRows, data, regressionNode.SplitAttribute, regressionNode.SplitValue, out lTrainingRows, out rTrainingRows);
      RegressionTreeUtilities.SplitRows(testRows, data, regressionNode.SplitAttribute, regressionNode.SplitValue, out lTestRows, out rTestRows);

      textNode.AddSubtree(MirrorTree(regressionNode.Left, dict, trainingLeafRows, testLeafRows, nextId, data, lTrainingRows, lTestRows));
      textNode.AddSubtree(MirrorTree(regressionNode.Right, dict, trainingLeafRows, testLeafRows, nextId, data, rTrainingRows, rTestRows));

      return textNode;
    }

    // Adds 1 to counts[key], inserting the key with count 1 if it is not present yet.
    private static void IncrementCount(IDictionary<string, int> counts, string key) {
      int count;
      counts.TryGetValue(key, out count);
      counts[key] = count + 1;
    }

    // Display-only symbol whose name is the text shown for a node; accepts any number of subtrees.
    [StorableType("D5540C63-310B-4D6F-8A3D-6C1A08DE7F80")]
    private sealed class TextSymbol : Symbol {
      [StorableConstructor]
      private TextSymbol(StorableConstructorFlag _) : base(_) { }
      private TextSymbol(Symbol original, Cloner cloner) : base(original, cloner) { }
      public TextSymbol(string name) : base(name, "") {
        Name = name;
      }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new TextSymbol(this, cloner);
      }
      public override int MinimumArity {
        get { return 0; }
      }
      public override int MaximumArity {
        get { return int.MaxValue; }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.