source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/Spliting/SplitterBase.cs @ 15830

Last change on this file since 15830 was 15830, checked in by bwerth, 16 months ago

#2847 adapted project to new rep structure; major changes to interfaces; restructures splitting and pruning

File size: 6.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
31namespace HeuristicLab.Algorithms.DataAnalysis {
32  [StorableClass]
33  [Item("SplitterBase", "A split selector that uses the ratio between Variances^(1/Order) to determine good splits")]
34  public abstract class SplitterBase : ParameterizedNamedItem, ISplitter {
35    public const string SplittingStateVariableName = "RuleSetState";
36
37    #region Constructors & Cloning
38    [StorableConstructor]
39    protected SplitterBase(bool deserializing) { }
40    protected SplitterBase(SplitterBase original, Cloner cloner) : base(original, cloner) { }
41    public SplitterBase() { }
42    #endregion
43
44    #region ISplitType
45    public void Initialize(IScope states) {
46      states.Variables.Add(new Variable(SplittingStateVariableName, new SplittingState()));
47    }
48    public void Split(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IScope stateScope, CancellationToken cancellationToken) {
49      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[M5Regression.RegressionTreeParameterVariableName].Value;
50      var splittingState = (SplittingState)stateScope.Variables[SplittingStateVariableName].Value;
51      var variables = regressionTreeParams.AllowedInputVariables.ToArray();
52      var target = regressionTreeParams.TargetVariable;
53
54      if (splittingState.Code <= 0) {
55        splittingState.nodeQueue.Enqueue(tree.Root);
56        splittingState.trainingRowsQueue.Enqueue(trainingRows);
57        splittingState.Code = 1;
58      }
59      while (splittingState.nodeQueue.Count != 0) {
60        var n = splittingState.nodeQueue.Dequeue();
61        var rows = splittingState.trainingRowsQueue.Dequeue();
62
63        string attr;
64        double splitValue;
65        var isLeaf = !DecideSplit(new RegressionProblemData(RegressionTreeUtilities.ReduceDataset(regressionTreeParams.Data, rows, variables, target), variables, target), regressionTreeParams.MinLeafSize, out attr, out splitValue);
66        if (isLeaf) continue;
67
68        IReadOnlyList<int> leftRows, rightRows;
69        RegressionTreeUtilities.SplitRows(rows, regressionTreeParams.Data, attr, splitValue, out leftRows, out rightRows);
70        n.Split(regressionTreeParams, attr, splitValue, rows.Count);
71
72        splittingState.nodeQueue.Enqueue(n.Left);
73        splittingState.nodeQueue.Enqueue(n.Right);
74        splittingState.trainingRowsQueue.Enqueue(leftRows);
75        splittingState.trainingRowsQueue.Enqueue(rightRows);
76        cancellationToken.ThrowIfCancellationRequested();
77      }
78    }
79
80    protected virtual bool DecideSplit(IRegressionProblemData splitData, int minLeafSize, out string splitAttr, out double splitValue) {
81      var bestPos = 0;
82      var bestImpurity = double.MinValue;
83      var bestSplitValue = 0.0;
84      var bestSplitAttr = string.Empty;
85      splitAttr = bestSplitAttr;
86      splitValue = bestSplitValue;
87      if (splitData.Dataset.Rows < minLeafSize) return false;
88
89      //find best Attribute for the Splitter
90      foreach (var attr in splitData.AllowedInputVariables) {
91        int pos;
92        double impurity, sValue;
93        var sortedData = splitData.Dataset.GetDoubleValues(attr).Zip(splitData.Dataset.GetDoubleValues(splitData.TargetVariable), Tuple.Create).OrderBy(x => x.Item1).ToArray();
94        AttributeSplit(sortedData.Select(x => x.Item1).ToArray(), sortedData.Select(x => x.Item2).ToArray(), minLeafSize, out pos, out impurity, out sValue);
95        if (!(bestImpurity < impurity)) continue;
96        bestImpurity = impurity;
97        bestPos = pos;
98        bestSplitValue = sValue;
99        bestSplitAttr = attr;
100      }
101
102      splitAttr = bestSplitAttr;
103      splitValue = bestSplitValue;
104      //if no suitable split exists => leafNode
105      return bestPos + 1 >= minLeafSize && bestPos <= splitData.Dataset.Rows - minLeafSize;
106    }
107
108    protected abstract void AttributeSplit(IReadOnlyList<double> attValues, IReadOnlyList<double> targetValues, int minLeafSize, out int position, out double maxImpurity, out double splitValue);
109    #endregion
110
111    [StorableClass]
112    public class SplittingState : Item {
113      [Storable]
114      public Queue<RegressionNodeModel> nodeQueue = new Queue<RegressionNodeModel>();
115      [Storable]
116      public Queue<IReadOnlyList<int>> trainingRowsQueue = new Queue<IReadOnlyList<int>>();
117
118      //State.Code values denote the current action (for pausing)
119      //0...nothing has been done;
120      //1...splitting nodes;
121      [Storable]
122      public int Code = 0;
123
124      #region HLConstructors & Cloning
125      [StorableConstructor]
126      protected SplittingState(bool deserializing) : base(deserializing) { }
127      protected SplittingState(SplittingState original, Cloner cloner) : base(original, cloner) {
128        nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
129        trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
130        Code = original.Code;
131      }
132      public SplittingState() { }
133      public override IDeepCloneable Clone(Cloner cloner) {
134        return new SplittingState(this, cloner);
135      }
136      #endregion
137    }
138  }
139}
Note: See TracBrowser for help on using the repository browser.