Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/MetaModels/RegressionRuleSetModel.cs @ 18183

Last change on this file since 18183 was 17209, checked in by gkronber, 5 years ago

#2994: merged r17132:17198 from trunk to branch

File size: 8.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Problems.DataAnalysis;
31using HEAL.Attic;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
34  [StorableType("7B4D9AE9-0456-4029-80A6-CCB5E33CE356")]
35  public class RegressionRuleSetModel : RegressionModel, IDecisionTreeModel {
36    private const string NumRulesResultName = "Number of rules";
37    private const string CoveredInstancesResultName = "Covered instances";
38    public const string RuleSetStateVariableName = "RuleSetState";
39
40    #region Properties
41    [Storable]
42    internal List<RegressionRuleModel> Rules { get; private set; }
43    #endregion
44
45    #region HLConstructors & Cloning
46    [StorableConstructor]
47    protected RegressionRuleSetModel(StorableConstructorFlag _) : base(_) { }
48    protected RegressionRuleSetModel(RegressionRuleSetModel original, Cloner cloner) : base(original, cloner) {
49      if (original.Rules != null) Rules = original.Rules.Select(cloner.Clone).ToList();
50    }
51    protected RegressionRuleSetModel(string targetVariable) : base(targetVariable) { }
52    public override IDeepCloneable Clone(Cloner cloner) {
53      return new RegressionRuleSetModel(this, cloner);
54    }
55    #endregion
56
57    internal static RegressionRuleSetModel CreateRuleModel(string targetAttr, RegressionTreeParameters regressionTreeParams) {
58      return regressionTreeParams.LeafModel.ProvidesConfidence ? new ConfidenceRegressionRuleSetModel(targetAttr) : new RegressionRuleSetModel(targetAttr);
59    }
60
61    #region RegressionModel
62    public override IEnumerable<string> VariablesUsedForPrediction {
63      get {
64        var f = Rules.FirstOrDefault();
65        return f != null ? (f.VariablesUsedForPrediction ?? new List<string>()) : new List<string>();
66      }
67    }
68    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
69      if (Rules == null) throw new NotSupportedException("The model has not been built yet");
70      return rows.Select(row => GetEstimatedValue(dataset, row));
71    }
72    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
73      return new RegressionSolution(this, problemData);
74    }
75    #endregion
76
77    #region IDecisionTreeModel
78    public void Build(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope stateScope, ResultCollection results, CancellationToken cancellationToken) {
79      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
80      var ruleSetState = (RuleSetState)stateScope.Variables[RuleSetStateVariableName].Value;
81
82      if (ruleSetState.Code <= 0) {
83        ruleSetState.Rules.Clear();
84        ruleSetState.TrainingRows = trainingRows;
85        ruleSetState.PruningRows = pruningRows;
86        ruleSetState.Code = 1;
87      }
88
89      do {
90        var tempRule = RegressionRuleModel.CreateRuleModel(regressionTreeParams.TargetVariable, regressionTreeParams);
91        cancellationToken.ThrowIfCancellationRequested();
92
93        if (!results.ContainsKey(NumRulesResultName)) results.Add(new Result(NumRulesResultName, new IntValue(0)));
94        if (!results.ContainsKey(CoveredInstancesResultName)) results.Add(new Result(CoveredInstancesResultName, new IntValue(0)));
95
96        var t1 = ruleSetState.TrainingRows.Count;
97        tempRule.Build(ruleSetState.TrainingRows, ruleSetState.PruningRows, stateScope, results, cancellationToken);
98        ruleSetState.TrainingRows = ruleSetState.TrainingRows.Where(i => !tempRule.Covers(regressionTreeParams.Data, i)).ToArray();
99        ruleSetState.PruningRows = ruleSetState.PruningRows.Where(i => !tempRule.Covers(regressionTreeParams.Data, i)).ToArray();
100        ruleSetState.Rules.Add(tempRule);
101        ((IntValue)results[NumRulesResultName].Value).Value++;
102        ((IntValue)results[CoveredInstancesResultName].Value).Value += t1 - ruleSetState.TrainingRows.Count;
103      }
104      while (ruleSetState.TrainingRows.Count > 0);
105      Rules = ruleSetState.Rules;
106    }
107    public void Update(IReadOnlyList<int> rows, IScope stateScope, CancellationToken cancellationToken) {
108      foreach (var rule in Rules) rule.Update(rows, stateScope, cancellationToken);
109    }
110    public static void Initialize(IScope stateScope) {
111      stateScope.Variables.Add(new Variable(RuleSetStateVariableName, new RuleSetState()));
112    }
113    #endregion
114
115    #region Helpers
116    private double GetEstimatedValue(IDataset dataset, int row) {
117      foreach (var rule in Rules) {
118        if (rule.Covers(dataset, row))
119          return rule.GetEstimatedValues(dataset, row.ToEnumerable()).Single();
120      }
121      throw new ArgumentException("Instance is not covered by any rule");
122    }
123    #endregion
124
125    [StorableType("E114F3C9-3C1F-443D-8270-0E10CE12F2A0")]
126    public class RuleSetState : Item {
127      [Storable]
128      public List<RegressionRuleModel> Rules = new List<RegressionRuleModel>();
129      [Storable]
130      public IReadOnlyList<int> TrainingRows = new List<int>();
131      [Storable]
132      public IReadOnlyList<int> PruningRows = new List<int>();
133
134      //State.Code values denote the current action (for pausing)
135      //0...nothing has been done;
136      //1...splitting nodes;
137      [Storable]
138      public int Code = 0;
139
140      #region HLConstructors & Cloning
141      [StorableConstructor]
142      protected RuleSetState(StorableConstructorFlag _) : base(_) { }
143      protected RuleSetState(RuleSetState original, Cloner cloner) : base(original, cloner) {
144        Rules = original.Rules.Select(cloner.Clone).ToList();
145        TrainingRows = original.TrainingRows.ToList();
146        PruningRows = original.PruningRows.ToList();
147
148        Code = original.Code;
149      }
150      public RuleSetState() { }
151      public override IDeepCloneable Clone(Cloner cloner) {
152        return new RuleSetState(this, cloner);
153      }
154      #endregion
155    }
156
157    [StorableType("52E7992B-94CC-4960-AA82-1A399BE735C6")]
158    private sealed class ConfidenceRegressionRuleSetModel : RegressionRuleSetModel, IConfidenceRegressionModel {
159      #region HLConstructors & Cloning
160      [StorableConstructor]
161      private ConfidenceRegressionRuleSetModel(StorableConstructorFlag _) : base(_) { }
162      private ConfidenceRegressionRuleSetModel(ConfidenceRegressionRuleSetModel original, Cloner cloner) : base(original, cloner) { }
163      public ConfidenceRegressionRuleSetModel(string targetVariable) : base(targetVariable) { }
164      public override IDeepCloneable Clone(Cloner cloner) {
165        return new ConfidenceRegressionRuleSetModel(this, cloner);
166      }
167      #endregion
168
169      #region IConfidenceRegressionModel
170      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
171        if (Rules == null) throw new NotSupportedException("The model has not been built yet");
172        return rows.Select(row => GetEstimatedVariance(dataset, row));
173      }
174      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
175        return new ConfidenceRegressionSolution(this, problemData);
176      }
177      private double GetEstimatedVariance(IDataset dataset, int row) {
178        foreach (var rule in Rules) {
179          if (rule.Covers(dataset, row)) return ((IConfidenceRegressionModel)rule).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
180        }
181        throw new ArgumentException("Instance is not covered by any rule");
182      }
183      #endregion
184    }
185  }
186}
Note: See TracBrowser for help on using the repository browser.