Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Algorithms.DataAnalysis.DecisionTrees/3.4/MetaModels/RegressionNodeTreeModel.cs @ 18242

Last change on this file since 18242 was 17180, checked in by swagner, 5 years ago

#2875: Removed years in copyrights

File size: 6.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Optimization;
29using HeuristicLab.Problems.DataAnalysis;
30using HEAL.Attic;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression model backed by a tree of <see cref="RegressionNodeModel"/> nodes.
  /// The tree is split by the configured splitter, pruned by the configured pruning
  /// strategy, and its leaf models are built by the configured leaf model, all taken from
  /// the <see cref="RegressionTreeParameters"/> stored in the state scope.
  /// </summary>
  [StorableType("FAF1F955-82F3-4824-9759-9D2846E831AE")]
  public class RegressionNodeTreeModel : RegressionModel, IDecisionTreeModel {
    public const string NumCurrentLeafsResultName = "Number of current leafs";
    public const string RootVariableName = "Root";

    #region Properties
    /// <summary>Root node of the tree; null until the tree has been created by a build call.</summary>
    [Storable]
    internal RegressionNodeModel Root { get; private set; }
    #endregion

    #region HLConstructors & Cloning
    [StorableConstructor]
    protected RegressionNodeTreeModel(StorableConstructorFlag _) : base(_) { }
    protected RegressionNodeTreeModel(RegressionNodeTreeModel original, Cloner cloner) : base(original, cloner) {
      Root = cloner.Clone(original.Root);
    }
    protected RegressionNodeTreeModel(string targetVariable) : base(targetVariable) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionNodeTreeModel(this, cloner);
    }
    #endregion

    /// <summary>
    /// Creates an empty (not yet built) tree model for <paramref name="targetAttr"/>.
    /// Returns a confidence-capable model when the configured leaf model provides confidence.
    /// </summary>
    /// <param name="targetAttr">Name of the target variable.</param>
    /// <param name="regressionTreeParams">Tree parameters; only <c>LeafModel.ProvidesConfidence</c> is inspected here.</param>
    internal static RegressionNodeTreeModel CreateTreeModel(string targetAttr, RegressionTreeParameters regressionTreeParams) {
      return regressionTreeParams.LeafModel.ProvidesConfidence ? new ConfidenceRegressionNodeTreeModel(targetAttr) : new RegressionNodeTreeModel(targetAttr);
    }

    #region RegressionModel
    /// <summary>
    /// Variables referenced by the tree. Empty when the tree has not been built yet
    /// or when the root reports no variables.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get {
        // An unbuilt model uses no variables; guard against dereferencing a null Root.
        if (Root == null) return new List<string>();
        return Root.VariablesUsedForPrediction ?? new List<string>();
      }
    }

    /// <summary>Predicts the target for the given <paramref name="rows"/> of <paramref name="dataset"/>.</summary>
    /// <exception cref="NotSupportedException">The tree has not been built yet.</exception>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      if (Root == null) throw new NotSupportedException("The model has not been built yet");
      return Root.GetEstimatedValues(dataset, rows);
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, problemData);
    }
    #endregion

    #region IDecisionTreeModel
    /// <summary>
    /// Builds the complete model. It creates a root node if none exists yet, splits it on
    /// <paramref name="trainingRows"/> and prunes it using <paramref name="pruningRows"/>.
    /// Finally it builds the leaf models on the union of training and pruning rows.
    /// </summary>
    /// <param name="trainingRows">Rows used for splitting.</param>
    /// <param name="pruningRows">Rows passed to the pruning strategy; also used for the final leaf models.</param>
    /// <param name="statescope">Scope holding the <see cref="RegressionTreeParameters"/> variable.</param>
    /// <param name="results">Not used by this implementation.</param>
    /// <param name="cancellationToken">Forwarded to the splitter, pruning and leaf model.</param>
    public void Build(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = GetTreeParameters(statescope);

      // An existing root is reused rather than replaced (in contrast to BuildModel).
      if (Root == null)
        Root = RegressionNodeModel.CreateNode(regressionTreeParams.TargetVariable, regressionTreeParams);

      SplitAndPrune(regressionTreeParams, trainingRows, pruningRows, statescope, cancellationToken);

      // Final leaf models see every row once: Union keeps training-row order and drops duplicates.
      regressionTreeParams.LeafModel.Build(this, trainingRows.Union(pruningRows).ToArray(), statescope, cancellationToken);
    }

    /// <summary>
    /// Rebuilds only the leaf models through the configured leaf model using <paramref name="rows"/>.
    /// No splitting or pruning is performed.
    /// </summary>
    public void Update(IReadOnlyList<int> rows, IScope statescope, CancellationToken cancellationToken) {
      var regressionTreeParams = GetTreeParameters(statescope);
      regressionTreeParams.LeafModel.Build(this, rows, statescope, cancellationToken);
    }

    /// <summary>
    /// Adds a variable named <see cref="RootVariableName"/> to <paramref name="stateScope"/>.
    /// The variable holds a newly created root node for the configured target variable.
    /// </summary>
    public static void Initialize(IScope stateScope) {
      var param = GetTreeParameters(stateScope);
      stateScope.Variables.Add(new Variable(RootVariableName, RegressionNodeModel.CreateNode(param.TargetVariable, param)));
    }
    #endregion

    /// <summary>
    /// Creates a new root node, discarding any existing tree. It then splits the tree on
    /// <paramref name="trainingRows"/> and prunes it using <paramref name="pruningRows"/>.
    /// Unlike <see cref="Build"/>, this does not build the final leaf models.
    /// </summary>
    /// <param name="results">Not used by this implementation.</param>
    public void BuildModel(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = GetTreeParameters(statescope);
      Root = RegressionNodeModel.CreateNode(regressionTreeParams.TargetVariable, regressionTreeParams);
      SplitAndPrune(regressionTreeParams, trainingRows, pruningRows, statescope, cancellationToken);
    }

    // Splits the tree starting from the current root (the result may be overfitted), then prunes it.
    private void SplitAndPrune(RegressionTreeParameters regressionTreeParams, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, CancellationToken cancellationToken) {
      regressionTreeParams.Splitter.Split(this, trainingRows, statescope, cancellationToken);
      regressionTreeParams.Pruning.Prune(this, trainingRows, pruningRows, statescope, cancellationToken);
    }

    // Reads the regression tree parameters stored in the given scope.
    private static RegressionTreeParameters GetTreeParameters(IScope scope) {
      return (RegressionTreeParameters)scope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
    }

    /// <summary>
    /// Tree model variant that also exposes prediction variances.
    /// It is created by <see cref="CreateTreeModel"/> when the leaf model provides confidence.
    /// </summary>
    [StorableType("E84ACC40-5694-4E40-A947-190673643206")]
    private sealed class ConfidenceRegressionNodeTreeModel : RegressionNodeTreeModel, IConfidenceRegressionModel {
      #region HLConstructors & Cloning
      [StorableConstructor]
      private ConfidenceRegressionNodeTreeModel(StorableConstructorFlag _) : base(_) { }
      private ConfidenceRegressionNodeTreeModel(ConfidenceRegressionNodeTreeModel original, Cloner cloner) : base(original, cloner) { }
      public ConfidenceRegressionNodeTreeModel(string targetVariable) : base(targetVariable) { }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new ConfidenceRegressionNodeTreeModel(this, cloner);
      }
      #endregion

      /// <summary>Returns the root's variance estimates for the given rows.</summary>
      /// <exception cref="NotSupportedException">The tree has not been built yet.</exception>
      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
        if (Root == null) throw new NotSupportedException("The model has not been built yet");
        // NOTE(review): assumes the root node implements IConfidenceRegressionModel. Presumably
        // this holds because this type is only created for confidence-providing leaf models; confirm.
        return ((IConfidenceRegressionModel)Root).GetEstimatedVariances(dataset, rows);
      }

      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
        return new ConfidenceRegressionSolution(this, problemData);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.