source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/M5Regression.cs @ 15967

Last change on this file since 15967 was 15967, checked in by bwerth, 12 months ago

#2847 added logistic dampening and some minor changes

File size: 14.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.PermutationEncoding;
9using HeuristicLab.Optimization;
10using HeuristicLab.Parameters;
11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12using HeuristicLab.PluginInfrastructure;
13using HeuristicLab.Problems.DataAnalysis;
14using HeuristicLab.Random;
15
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// M5 regression tree / rule set algorithm. Depending on <see cref="GenerateRules"/> it builds either a
  /// <see cref="RegressionNodeTreeModel"/> or a <see cref="RegressionRuleSetModel"/> from the training
  /// partition of the problem, using the selected splitter, leaf model and pruning operators.
  /// </summary>
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("M5RegressionTree", "A M5 regression tree / rule set")]
  public sealed class M5Regression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    public override bool SupportsPause {
      get { return true; }
    }

    // Names of the variables that InitializeScope stores in the state scope and Build reads back.
    public const string RegressionTreeParameterVariableName = "RegressionTreeParameters";
    public const string ModelVariableName = "Model";
    public const string PruningSetVariableName = "PruningSet";
    public const string TrainingSetVariableName = "TrainingSet";

    #region Parametername
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string HoldoutSizeParameterName = "HoldoutSize";
    private const string SplitterParameterName = "Splitter";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string LeafModelParameterName = "LeafModel";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string UseHoldoutParameterName = "UseHoldout";
    #endregion

    #region Parameter properties
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
    }
    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
    }
    /// <summary>The splitter parameter (the public name is kept for backward compatibility).</summary>
    public IConstrainedValueParameter<ISplitter> ImpurityParameter {
      get { return (IConstrainedValueParameter<ISplitter>)Parameters[SplitterParameterName]; }
    }
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
    }
    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
    }
    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    }
    #endregion

    #region Properties
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
    }
    public double HoldoutSize {
      get { return HoldoutSizeParameter.Value.Value; }
    }
    public ISplitter Splitter {
      get { return ImpurityParameter.Value; }
    }
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
    }
    public ILeafModel LeafModel {
      get { return LeafModelParameter.Value; }
    }
    public IPruning Pruning {
      get { return PruningTypeParameter.Value; }
    }
    public int Seed {
      get { return SeedParameter.Value.Value; }
    }
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
    }
    public bool UseHoldout {
      get { return UseHoldoutParameter.Value.Value; }
    }
    #endregion

    #region State
    // Scope holding the tree parameters, the (partially) built model and the training/pruning row sets.
    // It is created in Initialize and consumed in Run; it is [Storable] and cloned with the algorithm.
    [Storable]
    private IScope stateScope;
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private M5Regression(bool deserializing) : base(deserializing) { }
    private M5Regression(M5Regression original, Cloner cloner) : base(original, cloner) {
      stateScope = cloner.Clone(stateScope);
    }
    /// <summary>
    /// Creates the algorithm with all parameters at their defaults: M5Splitter, LinearLeaf, ComplexityPruning,
    /// minimal node size 1, 20% holdout (disabled), random seed, and an empty RegressionProblem.
    /// </summary>
    public M5Regression() {
      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
      var impuritySet = new ItemSet<ISplitter>(ApplicationManager.Manager.GetInstances<ISplitter>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning", new PercentValue(0.2)));
      Parameters.Add(new ConstrainedValueParameter<ISplitter>(SplitterParameterName, "The type of split function used to create node splits", impuritySet, impuritySet.OfType<M5Splitter>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used", pruningSet, pruningSet.OfType<ComplexityPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data ", new BoolValue(false)));
      Problem = new RegressionProblem();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new M5Regression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Seeds the random number generator (drawing a new seed first if SetSeedRandomly is set), builds the
    /// state scope from the current parameters and publishes it as the "StateScope" result.
    /// </summary>
    protected override void Initialize(CancellationToken cancellationToken) {
      base.Initialize(cancellationToken);
      var random = new MersenneTwister();
      // Write the drawn seed back into the parameter so the run stays reproducible.
      if (SetSeedRandomly) SeedParameter.Value.Value = new System.Random().Next();
      random.Reset(Seed);
      stateScope = InitializeScope(random, Problem.ProblemData, Pruning, MinimalNodeSize, LeafModel, Splitter, GenerateRules, UseHoldout, HoldoutSize);
      stateScope.Variables.Add(new Variable("Algorithm", this));
      Results.AddOrUpdateResult("StateScope", stateScope);
    }

    /// <summary>Builds the model stored in the state scope and adds the resulting solution and analyses to Results.</summary>
    protected override void Run(CancellationToken cancellationToken) {
      var model = Build(stateScope, Results, cancellationToken);
      AnalyzeSolution(model.CreateRegressionSolution(Problem.ProblemData), Results, Problem.ProblemData);
    }

    #region Static Interface
    /// <summary>
    /// Builds an M5 tree or rule set for <paramref name="problemData"/> without instantiating the algorithm.
    /// </summary>
    /// <param name="problemData">Problem whose training partition is used; all inputs and the target must be double-valued and finite.</param>
    /// <param name="random">Random number generator used for the holdout split and the operators; a new MersenneTwister is used when null.</param>
    /// <param name="leafModel">Leaf model; defaults to <see cref="LinearLeaf"/>.</param>
    /// <param name="splitter">Splitter; defaults to <see cref="M5Splitter"/>.</param>
    /// <param name="pruning">Pruning operator; defaults to <see cref="ComplexityPruning"/>.</param>
    /// <param name="useHoldout">If true, a fraction <paramref name="holdoutSize"/> of the training rows is reserved for pruning.</param>
    /// <param name="holdoutSize">Fraction of training rows reserved for pruning when <paramref name="useHoldout"/> is set.</param>
    /// <param name="minimumLeafSize">Minimal number of samples in a leaf.</param>
    /// <param name="generateRules">If true, a rule set is built instead of a tree.</param>
    /// <param name="results">Optional result collection handed to the model builder.</param>
    /// <param name="cancellationToken">Optional cancellation token; defaults to <see cref="CancellationToken.None"/>.</param>
    /// <returns>A regression solution wrapping the built model.</returns>
    /// <exception cref="NotSupportedException">Thrown for non-double variables or NaN/infinite training values.</exception>
    public static IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, IRandom random, ILeafModel leafModel = null, ISplitter splitter = null, IPruning pruning = null,
      bool useHoldout = false, double holdoutSize = 0.2, int minimumLeafSize = 1, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
      if (random == null) random = new MersenneTwister();
      if (leafModel == null) leafModel = new LinearLeaf();
      if (splitter == null) splitter = new M5Splitter();
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      if (pruning == null) pruning = new ComplexityPruning();

      var stateScope = InitializeScope(random, problemData, pruning, minimumLeafSize, leafModel, splitter, generateRules, useHoldout, holdoutSize);
      var model = Build(stateScope, results, cancellationToken.Value);
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Re-fits an existing M5 model on the training rows of <paramref name="problemData"/> using <paramref name="leafModel"/>.
    /// </summary>
    public static void UpdateModel(IM5Model model, IRegressionProblemData problemData, IRandom random, ILeafModel leafModel, CancellationToken? cancellationToken = null) {
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      var regressionTreeParameters = new RegressionTreeParameters(leafModel, problemData, random);
      var scope = new Scope();
      scope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParameters));
      leafModel.Initialize(scope);
      model.Update(problemData.TrainingIndices.ToList(), scope, cancellationToken.Value);
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Creates the state scope for a build: a reduced problem data set (training rows, allowed inputs and
    /// target only), the tree parameters, the initialized operators, the unbuilt model and the
    /// training/pruning row indices.
    /// </summary>
    /// <exception cref="NotSupportedException">Thrown for non-double variables or NaN/infinite training values.</exception>
    private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize) {
      var stateScope = new Scope("RegressionTreeStateScope");

      //reduce RegressionProblemData to AllowedInput & Target column wise and to TrainingSet row wise
      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("M5 regression supports only double valued input or output features.");
      var doubles = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (doubles.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("M5 regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, doubles);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      // Every row of the reduced data set is a training row; the test partition is empty.
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //store regression tree parameters
      var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);
      stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

      //initialize tree operators
      pruning.Initialize(stateScope);
      splitter.Initialize(stateScope);
      leafModel.Initialize(stateScope);

      //store unbuilt model
      IItem model;
      if (generateRules) {
        model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
        RegressionRuleSetModel.Initialize(stateScope);
      }
      else {
        model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
      }
      stateScope.Variables.Add(new Variable(ModelVariableName, model));

      //store training & pruning indices
      IReadOnlyList<int> trainingSet, pruningSet;
      GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
      stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
      stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

      return stateScope;
    }

    /// <summary>
    /// Builds the model stored in <paramref name="stateScope"/>. Degenerate inputs short-circuit to constant models:
    /// no training rows yields the constant 0, and fewer training rows than the minimal leaf size yields the target mean.
    /// </summary>
    private static IRegressionModel Build(IScope stateScope, ResultCollection results, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[RegressionTreeParameterVariableName].Value;
      var model = (IM5Model)stateScope.Variables[ModelVariableName].Value;
      var trainingRows = (IntArray)stateScope.Variables[TrainingSetVariableName].Value;
      var pruningRows = (IntArray)stateScope.Variables[PruningSetVariableName].Value;
      if (trainingRows.Length < 1)
        return new PreconstructedLinearModel(new Dictionary<string, double>(), 0, regressionTreeParams.TargetVariable);
      if (regressionTreeParams.MinLeafSize > trainingRows.Length) {
        // NOTE(review): averages the target over every row of regressionTreeParams.Data (presumably the reduced
        // data set from InitializeScope, which also contains holdout rows) rather than only the training rows.
        var targets = regressionTreeParams.Data.GetDoubleValues(regressionTreeParams.TargetVariable).ToArray();
        return new PreconstructedLinearModel(new Dictionary<string, double>(), targets.Average(), regressionTreeParams.TargetVariable);
      }
      model.Build(trainingRows.ToArray(), pruningRows.ToArray(), stateScope, results, cancellationToken);
      return model;
    }

    /// <summary>
    /// Splits <paramref name="allrows"/> into training and pruning rows. Without holdout both sets are
    /// <paramref name="allrows"/> itself. With holdout the rows are visited in a random permutation: the first
    /// <c>(int)(holdoutSize * allrows.Count)</c> become the pruning set and the remaining rows the training set,
    /// so the two sets are disjoint and together cover all rows.
    /// </summary>
    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
      if (!useHoldout) {
        training = allrows;
        pruning = allrows;
        return;
      }
      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
      var cut = (int)(holdoutSize * allrows.Count);
      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
      // Previously Take(cut) was used here as well, which made the training set identical to the pruning set
      // and discarded the remaining rows; the training set must be the complement of the holdout.
      training = perm.Skip(cut).Select(i => allrows[i]).ToArray();
    }

    /// <summary>
    /// Adds the solution and model-specific analyses (leaf depth histogram and node analysis for trees,
    /// rules and coverage for rule sets, relative variable frequencies, pruning chart) to <paramref name="results"/>.
    /// </summary>
    private void AnalyzeSolution(IRegressionSolution solution, ResultCollection results, IRegressionProblemData problemData) {
      results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));

      Dictionary<string, int> frequencies = null;

      var tree = solution.Model as RegressionNodeTreeModel;
      if (tree != null) {
        results.Add(RegressionTreeAnalyzer.CreateLeafDepthHistogram(tree));
        frequencies = RegressionTreeAnalyzer.GetTreeVariableFrequences(tree);
        RegressionTreeAnalyzer.AnalyzeNodes(tree, results, problemData);
      }

      var ruleSet = solution.Model as RegressionRuleSetModel;
      if (ruleSet != null) {
        results.Add(RegressionTreeAnalyzer.CreateRulesResult(ruleSet, problemData, "M5Rules", true));
        frequencies = RegressionTreeAnalyzer.GetRuleVariableFrequences(ruleSet);
        results.Add(RegressionTreeAnalyzer.CreateCoverageDiagram(ruleSet, problemData));
      }

      //Variable frequencies
      if (frequencies != null) {
        var sum = frequencies.Values.Sum();
        // Avoid division by zero when no variable is used (e.g. a single-leaf tree).
        sum = sum == 0 ? 1 : sum;
        var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
          ElementNames = frequencies.Select(i => i.Key)
        };
        results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
      }

      var pruning = Pruning as ComplexityPruning;
      if (pruning != null && tree != null)
        RegressionTreeAnalyzer.PruningChart(tree, pruning, results);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.