
source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/MetaModels/RegressionRuleModel.cs @ 18183

Last change on this file since 18183 was 17209, checked in by gkronber, 5 years ago

#2994: merged r17132:17198 from trunk to branch

File size: 7.8 KB
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HEAL.Attic;

namespace HeuristicLab.Algorithms.DataAnalysis {
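  //a regression rule: a conjunction of split conditions extracted from a regression tree,
  //together with a leaf regression model that is used for prediction on all rows covered by the rule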
  [StorableType("425AF262-A756-4E9A-B76F-4D2480BEA4FD")]
  public class RegressionRuleModel : RegressionModel, IDecisionTreeModel {
    #region Properties
    [Storable]
    public string[] SplitAttributes { get; set; }
    [Storable]
    private double[] SplitValues { get; set; }
    [Storable]
    private Comparison[] Comparisons { get; set; }
    [Storable]
    private IRegressionModel RuleModel { get; set; }
    [Storable]
    private IReadOnlyList<string> variables;
    #endregion

    #region HLConstructors
    [StorableConstructor]
    protected RegressionRuleModel(StorableConstructorFlag _) : base(_) { }
    protected RegressionRuleModel(RegressionRuleModel original, Cloner cloner) : base(original, cloner) {
      if (original.SplitAttributes != null) SplitAttributes = original.SplitAttributes.ToArray();
      if (original.SplitValues != null) SplitValues = original.SplitValues.ToArray();
      if (original.Comparisons != null) Comparisons = original.Comparisons.ToArray();
      RuleModel = cloner.Clone(original.RuleModel);
      if (original.variables != null) variables = original.variables.ToList();
    }
    private RegressionRuleModel(string target) : base(target) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionRuleModel(this, cloner);
    }
    #endregion

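    //factory method: creates a confidence-aware rule model if the configured leaf model provides variance estimates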
    internal static RegressionRuleModel CreateRuleModel(string target, RegressionTreeParameters regressionTreeParams) {
      return regressionTreeParams.LeafModel.ProvidesConfidence ? new ConfidenceRegressionRuleModel(target) : new RegressionRuleModel(target);
    }

    #region IRegressionModel
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return variables; }
    }

    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      if (RuleModel == null) throw new NotSupportedException("The model has not been built correctly");
      return RuleModel.GetEstimatedValues(dataset, rows);
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, problemData);
    }
    #endregion

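    //builds a full regression tree, selects the leaf covering the most samples, stores the split conditions
    //on the path from that leaf to the root and fits the leaf model on all rows covered by the rule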
    public void Build(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
      variables = regressionTreeParams.AllowedInputVariables.ToList();

      //build tree and select node with maximum coverage
      var tree = RegressionNodeTreeModel.CreateTreeModel(regressionTreeParams.TargetVariable, regressionTreeParams);
      tree.BuildModel(trainingRows, pruningRows, statescope, results, cancellationToken);
      var nodeModel = tree.Root.EnumerateNodes().Where(x => x.IsLeaf).MaxItems(x => x.NumSamples).First();

      var satts = new List<string>();
      var svals = new List<double>();
      var reops = new List<Comparison>();

      //extract splits
      for (var temp = nodeModel; temp.Parent != null; temp = temp.Parent) {
        satts.Add(temp.Parent.SplitAttribute);
        svals.Add(temp.Parent.SplitValue);
        reops.Add(temp.Parent.Left == temp ? Comparison.LessEqual : Comparison.Greater);
      }
      Comparisons = reops.ToArray();
      SplitAttributes = satts.ToArray();
      SplitValues = svals.ToArray();
      int np;
      RuleModel = regressionTreeParams.LeafModel.BuildModel(trainingRows.Union(pruningRows).Where(r => Covers(regressionTreeParams.Data, r)).ToArray(), regressionTreeParams, cancellationToken, out np);
    }

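    //refits the leaf model on the given rows; the stored split conditions remain unchanged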
    public void Update(IReadOnlyList<int> rows, IScope statescope, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
      int np;
      RuleModel = regressionTreeParams.LeafModel.BuildModel(rows, regressionTreeParams, cancellationToken, out np);
    }

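    //a row is covered if it satisfies every split condition of the rule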
    public bool Covers(IDataset dataset, int row) {
      return !SplitAttributes.Where((t, i) => !Comparisons[i].Compare(dataset.GetDoubleValue(t, row), SplitValues[i])).Any();
    }

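    //renders the rule as a conjunction of interval conditions, one interval per split attribute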
    public string ToCompactString() {
      var mins = new Dictionary<string, double>();
      var maxs = new Dictionary<string, double>();
      for (var i = 0; i < SplitAttributes.Length; i++) {
        var n = SplitAttributes[i];
        var v = SplitValues[i];
        if (!mins.ContainsKey(n)) mins.Add(n, double.NegativeInfinity);
        if (!maxs.ContainsKey(n)) maxs.Add(n, double.PositiveInfinity);
        if (Comparisons[i] == Comparison.LessEqual) maxs[n] = Math.Min(maxs[n], v);
        else mins[n] = Math.Max(mins[n], v);
      }
      if (maxs.Count == 0) return "";
      var s = new StringBuilder();
      foreach (var key in maxs.Keys)
        s.Append(string.Format("{0} ∈ [{1:e2}; {2:e2}] && ", key, mins[key], maxs[key]));
      s.Remove(s.Length - 4, 4);
      return s.ToString();
    }

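    //variant used when the leaf model provides confidence estimates; exposes the variances of the underlying leaf model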
    [StorableType("7302AA30-9F58-42F3-BF6A-ECF1536508AB")]
    private sealed class ConfidenceRegressionRuleModel : RegressionRuleModel, IConfidenceRegressionModel {
      #region HLConstructors
      [StorableConstructor]
      private ConfidenceRegressionRuleModel(StorableConstructorFlag _) : base(_) { }
      private ConfidenceRegressionRuleModel(ConfidenceRegressionRuleModel original, Cloner cloner) : base(original, cloner) { }
      public ConfidenceRegressionRuleModel(string targetAttr) : base(targetAttr) { }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new ConfidenceRegressionRuleModel(this, cloner);
      }
      #endregion

      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
        return ((IConfidenceRegressionModel)RuleModel).GetEstimatedVariances(dataset, rows);
      }

      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
        return new ConfidenceRegressionSolution(this, problemData);
      }
    }
  }

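  //comparison operators used in the split conditions of a rule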
  [StorableType("152DECE4-2692-4D53-B290-974806ADCD72")]
  internal enum Comparison {
    LessEqual,
    Greater
  }

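  //evaluates a split condition: x <= y for LessEqual, x > y for Greater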
  internal static class ComparisonExtentions {
    public static bool Compare(this Comparison op, double x, double y) {
      switch (op) {
        case Comparison.Greater:
          return x > y;
        case Comparison.LessEqual:
          return x <= y;
        default:
          throw new ArgumentOutOfRangeException(op.ToString(), op, null);
      }
    }
  }
}