source: branches/2847_M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/Spliting/SplitterBase.cs @ 16847

Last change on this file since 16847 was 16847, checked in by gkronber, 2 months ago

#2847: made some minor changes while reviewing

File size: 6.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Problems.DataAnalysis;
29using HEAL.Attic;
30
31namespace HeuristicLab.Algorithms.DataAnalysis {
32  [StorableType("22DCCF28-8943-4622-BBD3-B2AB04F28C36")]
33  [Item("SplitterBase", "Abstract base class for splitters")]
34  public abstract class SplitterBase : ParameterizedNamedItem, ISplitter {
35    public const string SplittingStateVariableName = "RuleSetState";
36
37    #region Constructors & Cloning
38    [StorableConstructor]
39    protected SplitterBase(StorableConstructorFlag _) { }
40    protected SplitterBase(SplitterBase original, Cloner cloner) : base(original, cloner) { }
41    public SplitterBase() { }
42    #endregion
43
44    #region ISplitType
45    public void Initialize(IScope states) {
46      states.Variables.Add(new Variable(SplittingStateVariableName, new SplittingState()));
47    }
48
49    public void Split(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IScope stateScope, CancellationToken cancellationToken) {
50      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[M5Regression.RegressionTreeParameterVariableName].Value;
51      var splittingState = (SplittingState)stateScope.Variables[SplittingStateVariableName].Value;
52      var variables = regressionTreeParams.AllowedInputVariables.ToArray();
53      var target = regressionTreeParams.TargetVariable;
54
55      if (splittingState.Code <= 0) {
56        splittingState.nodeQueue.Enqueue(tree.Root);
57        splittingState.trainingRowsQueue.Enqueue(trainingRows);
58        splittingState.Code = 1;
59      }
60      while (splittingState.nodeQueue.Count != 0) {
61        var n = splittingState.nodeQueue.Dequeue();
62        var rows = splittingState.trainingRowsQueue.Dequeue();
63
64        string attr;
65        double splitValue;
66        var isLeaf = !DecideSplit(new RegressionProblemData(RegressionTreeUtilities.ReduceDataset(regressionTreeParams.Data, rows, variables, target), variables, target), regressionTreeParams.MinLeafSize, out attr, out splitValue);
67        if (isLeaf) continue;
68
69        IReadOnlyList<int> leftRows, rightRows;
70        RegressionTreeUtilities.SplitRows(rows, regressionTreeParams.Data, attr, splitValue, out leftRows, out rightRows);
71        n.Split(regressionTreeParams, attr, splitValue, rows.Count);
72
73        splittingState.nodeQueue.Enqueue(n.Left);
74        splittingState.nodeQueue.Enqueue(n.Right);
75        splittingState.trainingRowsQueue.Enqueue(leftRows);
76        splittingState.trainingRowsQueue.Enqueue(rightRows);
77        cancellationToken.ThrowIfCancellationRequested();
78      }
79    }
80
81    protected virtual bool DecideSplit(IRegressionProblemData splitData, int minLeafSize, out string splitAttr, out double splitValue) {
82      var bestPos = 0;
83      var bestImpurity = double.MinValue;
84      var bestSplitValue = 0.0;
85      var bestSplitAttr = string.Empty;
86      splitAttr = bestSplitAttr;
87      splitValue = bestSplitValue;
88      if (splitData.Dataset.Rows < minLeafSize) return false;
89
90      // find best attribute for the splitter
91      foreach (var attr in splitData.AllowedInputVariables) {
92        int pos;
93        double impurity, sValue;
94        var sortedData = splitData.Dataset.GetDoubleValues(attr).Zip(splitData.Dataset.GetDoubleValues(splitData.TargetVariable), Tuple.Create).OrderBy(x => x.Item1).ToArray();
95        AttributeSplit(sortedData.Select(x => x.Item1).ToArray(), sortedData.Select(x => x.Item2).ToArray(), minLeafSize, out pos, out impurity, out sValue);
96        if (!(bestImpurity < impurity)) continue;
97        bestImpurity = impurity;
98        bestPos = pos;
99        bestSplitValue = sValue;
100        bestSplitAttr = attr;
101      }
102
103      splitAttr = bestSplitAttr;
104      splitValue = bestSplitValue;
105      //if no suitable split exists => leafNode
106      return bestPos + 1 >= minLeafSize && bestPos <= splitData.Dataset.Rows - minLeafSize;
107    }
108
109    protected abstract void AttributeSplit(IReadOnlyList<double> attValues, IReadOnlyList<double> targetValues, int minLeafSize, out int position, out double maxImpurity, out double splitValue);
110    #endregion
111
112    [StorableType("BC1149FD-370E-4F3A-92F5-6E519736D09A")]
113    public class SplittingState : Item {
114      [Storable]
115      public Queue<RegressionNodeModel> nodeQueue = new Queue<RegressionNodeModel>();
116      [Storable]
117      public Queue<IReadOnlyList<int>> trainingRowsQueue = new Queue<IReadOnlyList<int>>();
118
119      //State.Code values denote the current action (for pausing)
120      //0...nothing has been done;
121      //1...splitting nodes;
122      [Storable]
123      public int Code = 0;
124
125      #region HLConstructors & Cloning
126      [StorableConstructor]
127      protected SplittingState(StorableConstructorFlag _) : base(_) { }
128      protected SplittingState(SplittingState original, Cloner cloner) : base(original, cloner) {
129        nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
130        trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
131        Code = original.Code;
132      }
133      public SplittingState() { }
134      public override IDeepCloneable Clone(Cloner cloner) {
135        return new SplittingState(this, cloner);
136      }
137      #endregion
138    }
139  }
140}
Note: See TracBrowser for help on using the repository browser.