source: branches/2847_M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/M5Utilities/RegressionTreeAnalyzer.cs @ 16847

Last change on this file since 16847 was 16847, checked in by gkronber, 4 months ago

#2847: made some minor changes while reviewing

File size: 12.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Optimization;
30using HeuristicLab.Problems.DataAnalysis;
31using HEAL.Attic;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
34  public static class RegressionTreeAnalyzer {
35    private const string ConditionResultName = "Condition";
36    private const string CoverResultName = "Covered Instances";
37    private const string CoverageDiagramResultName = "Coverage";
38    private const string RuleModelResultName = "RuleModel";
39
40    public static Dictionary<string, int> GetRuleVariableFrequences(RegressionRuleSetModel ruleSetModel) {
41      var res = ruleSetModel.VariablesUsedForPrediction.ToDictionary(x => x, x => 0);
42      foreach (var rule in ruleSetModel.Rules)
43      foreach (var att in rule.SplitAttributes)
44        res[att]++;
45      return res;
46    }
47
48    public static Dictionary<string, int> GetTreeVariableFrequences(RegressionNodeTreeModel treeModel) {
49      var res = treeModel.VariablesUsedForPrediction.ToDictionary(x => x, x => 0);
50      var root = treeModel.Root;
51      foreach (var cur in root.EnumerateNodes().Where(x => !x.IsLeaf))
52        res[cur.SplitAttribute]++;
53      return res;
54    }
55
56    public static Result CreateLeafDepthHistogram(RegressionNodeTreeModel treeModel) {
57      var list = new List<int>();
58      GetLeafDepths(treeModel.Root, 0, list);
59      var row = new DataRow("Depths", "", list.Select(x => (double)x)) {
60        VisualProperties = {ChartType = DataRowVisualProperties.DataRowChartType.Histogram}
61      };
62      var hist = new DataTable("LeafDepths");
63      hist.Rows.Add(row);
64      return new Result(hist.Name, hist);
65    }
66
67    public static Result CreateRulesResult(RegressionRuleSetModel ruleSetModel, IRegressionProblemData pd, string resultName, bool displayModels) {
68      var res = new ResultCollection();
69      var i = 0;
70      foreach (var rule in ruleSetModel.Rules)
71        res.Add(new Result("Rule" + i++, CreateRulesResult(rule, pd, displayModels, out pd)));
72      return new Result(resultName, res);
73    }
74
75    public static IResult CreateCoverageDiagram(RegressionRuleSetModel setModel, IRegressionProblemData problemData) {
76      var res = new DataTable(CoverageDiagramResultName);
77      var training = CountCoverage(setModel, problemData.Dataset, problemData.TrainingIndices);
78      var test = CountCoverage(setModel, problemData.Dataset, problemData.TestIndices);
79      res.Rows.Add(new DataRow("Training", "", training));
80      res.Rows.Add(new DataRow("Test", "", test));
81
82      foreach (var row in res.Rows)
83        row.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Columns;
84      res.VisualProperties.XAxisMaximumFixedValue = training.Count + 1;
85      res.VisualProperties.XAxisMaximumAuto = false;
86      res.VisualProperties.XAxisMinimumFixedValue = 0;
87      res.VisualProperties.XAxisMinimumAuto = false;
88      res.VisualProperties.XAxisTitle = "Rule";
89      res.VisualProperties.YAxisTitle = "Covered Instances";
90
91      return new Result(CoverageDiagramResultName, res);
92    }
93
94    private static void GetLeafDepths(RegressionNodeModel n, int depth, ICollection<int> res) {
95      if (n == null) return;
96      if (n.Left == null && n.Right == null) res.Add(depth);
97      else {
98        GetLeafDepths(n.Left, depth + 1, res);
99        GetLeafDepths(n.Right, depth + 1, res);
100      }
101    }
102
103    private static IScope CreateRulesResult(RegressionRuleModel regressionRuleModel, IRegressionProblemData pd, bool displayModels, out IRegressionProblemData notCovered) {
104      var training = pd.TrainingIndices.Where(x => !regressionRuleModel.Covers(pd.Dataset, x)).ToArray();
105      var test = pd.TestIndices.Where(x => !regressionRuleModel.Covers(pd.Dataset, x)).ToArray();
106      var data = new Dataset(pd.Dataset.DoubleVariables, pd.Dataset.DoubleVariables.Select(v => pd.Dataset.GetDoubleValues(v, training.Concat(test)).ToArray()));
107      notCovered = new RegressionProblemData(data, pd.AllowedInputVariables, pd.TargetVariable);
108      notCovered.TestPartition.Start = notCovered.TrainingPartition.End = training.Length;
109      notCovered.TestPartition.End = training.Length + test.Length;
110
111      var training2 = pd.TrainingIndices.Where(x => regressionRuleModel.Covers(pd.Dataset, x)).ToArray();
112      var test2 = pd.TestIndices.Where(x => regressionRuleModel.Covers(pd.Dataset, x)).ToArray();
113      var data2 = new Dataset(pd.Dataset.DoubleVariables, pd.Dataset.DoubleVariables.Select(v => pd.Dataset.GetDoubleValues(v, training2.Concat(test2)).ToArray()));
114      var covered = new RegressionProblemData(data2, pd.AllowedInputVariables, pd.TargetVariable);
115      covered.TestPartition.Start = covered.TrainingPartition.End = training2.Length;
116      covered.TestPartition.End = training2.Length + test2.Length;
117
118      var res2 = new Scope("RuleModels");
119      res2.Variables.Add(new Variable(ConditionResultName, new StringValue(regressionRuleModel.ToCompactString())));
120      res2.Variables.Add(new Variable(CoverResultName, new IntValue(pd.TrainingIndices.Count() - training.Length)));
121      if (displayModels)
122        res2.Variables.Add(new Variable(RuleModelResultName, regressionRuleModel.CreateRegressionSolution(covered)));
123      return res2;
124    }
125
126    private static IReadOnlyList<double> CountCoverage(RegressionRuleSetModel setModel, IDataset data, IEnumerable<int> rows) {
127      var rules = setModel.Rules.ToArray();
128      var res = new double[rules.Length];
129      foreach (var row in rows)
130        for (var i = 0; i < rules.Length; i++)
131          if (rules[i].Covers(data, row)) {
132            res[i]++;
133            break;
134          }
135      return res;
136    }
137
138    public static void AnalyzeNodes(RegressionNodeTreeModel tree, ResultCollection results, IRegressionProblemData pd) {
139      var dict = new Dictionary<int, RegressionNodeModel>();
140      var trainingLeafRows = new Dictionary<int, IReadOnlyList<int>>();
141      var testLeafRows = new Dictionary<int, IReadOnlyList<int>>();
142      var modelNumber = new IntValue(1);
143      var symtree = new SymbolicExpressionTree(MirrorTree(tree.Root, dict, trainingLeafRows, testLeafRows, modelNumber, pd.Dataset, pd.TrainingIndices.ToList(), pd.TestIndices.ToList()));
144      results.AddOrUpdateResult("DecisionTree", symtree);
145
146      if (dict.Count > 200) return;
147      var models = new Scope("NodeModels");
148      results.AddOrUpdateResult("NodeModels", models);
149      foreach (var m in dict.Keys.OrderBy(x => x))
150        models.Variables.Add(new Variable("Model " + m, dict[m].CreateRegressionSolution(Subselect(pd, trainingLeafRows[m], testLeafRows[m]))));
151    }
152
153    public static void PruningChart(RegressionNodeTreeModel tree, ComplexityPruning pruning, ResultCollection results) {
154      var nodes = new Queue<RegressionNodeModel>();
155      nodes.Enqueue(tree.Root);
156      var max = 0.0;
157      var strenghts = new SortedList<double, int>();
158      while (nodes.Count > 0) {
159        var n = nodes.Dequeue();
160
161        if (n.IsLeaf) {
162          max++;
163          continue;
164        }
165
166        if (!strenghts.ContainsKey(n.PruningStrength)) strenghts.Add(n.PruningStrength, 0);
167        strenghts[n.PruningStrength]++;
168        nodes.Enqueue(n.Left);
169        nodes.Enqueue(n.Right);
170      }
171      if (strenghts.Count == 0) return;
172
173      var plot = new ScatterPlot("Pruned Sizes", "") {
174        VisualProperties = {
175          XAxisTitle = "Pruning Strength",
176          YAxisTitle = "Tree Size",
177          XAxisMinimumAuto = false,
178          XAxisMinimumFixedValue = 0
179        }
180      };
181      var row = new ScatterPlotDataRow("TreeSizes", "", new List<Point2D<double>>());
182      row.Points.Add(new Point2D<double>(pruning.PruningStrength, max));
183
184      var fillerDots = new Queue<double>();
185      var minX = pruning.PruningStrength;
186      var maxX = strenghts.Last().Key;
187      var size = (maxX - minX) / 200;
188      for (var x = minX; x <= maxX; x += size) {
189        fillerDots.Enqueue(x);
190      }
191
192      foreach (var strenght in strenghts.Keys) {
193        while (fillerDots.Count > 0 && strenght > fillerDots.Peek())
194          row.Points.Add(new Point2D<double>(fillerDots.Dequeue(), max));
195        max -= strenghts[strenght];
196        row.Points.Add(new Point2D<double>(strenght, max));
197      }
198
199
200      row.VisualProperties.PointSize = 6;
201      plot.Rows.Add(row);
202      results.AddOrUpdateResult("PruningSizes", plot);
203    }
204
205
206    private static IRegressionProblemData Subselect(IRegressionProblemData data, IReadOnlyList<int> training, IReadOnlyList<int> test) {
207      var dataset = RegressionTreeUtilities.ReduceDataset(data.Dataset, training.Concat(test).ToList(), data.AllowedInputVariables.ToList(), data.TargetVariable);
208      var res = new RegressionProblemData(dataset, data.AllowedInputVariables, data.TargetVariable);
209      res.TrainingPartition.Start = 0;
210      res.TrainingPartition.End = training.Count;
211      res.TestPartition.Start = training.Count;
212      res.TestPartition.End = training.Count + test.Count;
213      return res;
214    }
215
216    private static SymbolicExpressionTreeNode MirrorTree(RegressionNodeModel regressionNode, IDictionary<int, RegressionNodeModel> dict,
217      IDictionary<int, IReadOnlyList<int>> trainingLeafRows,
218      IDictionary<int, IReadOnlyList<int>> testLeafRows,
219      IntValue nextId, IDataset data, IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows) {
220      if (regressionNode.IsLeaf) {
221        var i = nextId.Value++;
222        dict.Add(i, regressionNode);
223        trainingLeafRows.Add(i, trainingRows);
224        testLeafRows.Add(i, testRows);
225        return new SymbolicExpressionTreeNode(new TextSymbol("Model " + i + "\n(" + trainingRows.Count + "/" + testRows.Count + ")"));
226      }
227
228      var pftext = "\npf = " + regressionNode.PruningStrength.ToString("0.###");
229      var text = regressionNode.SplitAttribute + " <= " + regressionNode.SplitValue.ToString("0.###");
230      if (!double.IsNaN(regressionNode.PruningStrength)) text += pftext;
231
232      var textNode = new SymbolicExpressionTreeNode(new TextSymbol(text));
233      IReadOnlyList<int> lTrainingRows, rTrainingRows;
234      IReadOnlyList<int> lTestRows, rTestRows;
235      RegressionTreeUtilities.SplitRows(trainingRows, data, regressionNode.SplitAttribute, regressionNode.SplitValue, out lTrainingRows, out rTrainingRows);
236      RegressionTreeUtilities.SplitRows(testRows, data, regressionNode.SplitAttribute, regressionNode.SplitValue, out lTestRows, out rTestRows);
237
238      textNode.AddSubtree(MirrorTree(regressionNode.Left, dict, trainingLeafRows, testLeafRows, nextId, data, lTrainingRows, lTestRows));
239      textNode.AddSubtree(MirrorTree(regressionNode.Right, dict, trainingLeafRows, testLeafRows, nextId, data, rTrainingRows, rTestRows));
240
241      return textNode;
242    }
243
244
245    [StorableType("D5540C63-310B-4D6F-8A3D-6C1A08DE7F80")]
246    private class TextSymbol : Symbol {
247      [StorableConstructor]
248      private TextSymbol(StorableConstructorFlag _) : base(_) { }
249      private TextSymbol(Symbol original, Cloner cloner) : base(original, cloner) { }
250      public TextSymbol(string name) : base(name, "") {
251        Name = name;
252      }
253      public override IDeepCloneable Clone(Cloner cloner) {
254        return new TextSymbol(this, cloner);
255      }
256      public override int MinimumArity {
257        get { return 0; }
258      }
259      public override int MaximumArity {
260        get { return int.MaxValue; }
261      }
262    }
263  }
264}
Note: See TracBrowser for help on using the repository browser.