Free cookie consent management tool by TermsFeed Policy Generator

source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/M5Regression.cs @ 15614

Last change on this file since 15614 was 15614, checked in by bwerth, 6 years ago

#2847 made changes to M5 according to review comments

File size: 11.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.PermutationEncoding;
9using HeuristicLab.Optimization;
10using HeuristicLab.Parameters;
11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12using HeuristicLab.PluginInfrastructure;
13using HeuristicLab.Problems.DataAnalysis;
14using HeuristicLab.Random;
15
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Algorithm that builds an M5 regression tree or, when <see cref="GenerateRules"/> is set,
  /// an M5 rule set for the configured regression problem.
  /// </summary>
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("M5RegressionTree", "A M5 regression tree / rule set")]
  public sealed class M5Regression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region Parametername
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string HoldoutSizeParameterName = "HoldoutSize";
    private const string SpliterParameterName = "Spliter";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string LeafModelParameterName = "LeafModel";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string UseHoldoutParameterName = "UseHoldout";
    #endregion

    #region Parameter properties
    /// <summary>Whether a set of rules or a decision tree shall be created.</summary>
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[GenerateRulesParameterName]; }
    }
    /// <summary>Fraction of the training set reserved for pruning (only used when <see cref="UseHoldout"/> is set).</summary>
    public IFixedValueParameter<PercentValue> HoldoutSizeParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[HoldoutSizeParameterName]; }
    }
    /// <summary>The split function used to create node splits (stored under the "Spliter" parameter).</summary>
    public IConstrainedValueParameter<ISpliter> ImpurityParameter {
      get { return (IConstrainedValueParameter<ISpliter>)Parameters[SpliterParameterName]; }
    }
    /// <summary>The minimal number of samples in a leaf node.</summary>
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MinimalNodeSizeParameterName]; }
    }
    /// <summary>The type of model used for the nodes.</summary>
    public IConstrainedValueParameter<ILeafModel> LeafModelParameter {
      get { return (IConstrainedValueParameter<ILeafModel>)Parameters[LeafModelParameterName]; }
    }
    /// <summary>The type of pruning used.</summary>
    public IConstrainedValueParameter<IPruning> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruning>)Parameters[PruningTypeParameterName]; }
    }
    /// <summary>Seed for the pseudo random number generator.</summary>
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    /// <summary>If true, <see cref="Run"/> overwrites the seed with a random value before each run.</summary>
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    /// <summary>If true, a holdout set is split off for pruning; otherwise splitting and pruning use the same rows.</summary>
    public IFixedValueParameter<BoolValue> UseHoldoutParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UseHoldoutParameterName]; }
    }
    #endregion

    #region Properties
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
    }
    public double HoldoutSize {
      get { return HoldoutSizeParameter.Value.Value; }
    }
    public ISpliter Split {
      get { return ImpurityParameter.Value; }
    }
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
    }
    public ILeafModel LeafModel {
      get { return LeafModelParameter.Value; }
    }
    public IPruning Pruning {
      get { return PruningTypeParameter.Value; }
    }
    public int Seed {
      get { return SeedParameter.Value.Value; }
    }
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
    }
    public bool UseHoldout {
      get { return UseHoldoutParameter.Value.Value; }
    }
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private M5Regression(bool deserializing) : base(deserializing) { }
    private M5Regression(M5Regression original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the algorithm with its default parameters. Leaf models, pruning strategies and split
    /// functions are discovered via the plugin infrastructure; defaults are LinearLeaf,
    /// M5LinearBottomUpPruning and M5Spliter.
    /// </summary>
    public M5Regression() {
      var modelSet = new ItemSet<ILeafModel>(ApplicationManager.Manager.GetInstances<ILeafModel>());
      var pruningSet = new ItemSet<IPruning>(ApplicationManager.Manager.GetInstances<IPruning>());
      var impuritySet = new ItemSet<ISpliter>(ApplicationManager.Manager.GetInstances<ISpliter>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<PercentValue>(HoldoutSizeParameterName, "How much of the training set shall be reserved for pruning", new PercentValue(0.2)));
      Parameters.Add(new ConstrainedValueParameter<ISpliter>(SpliterParameterName, "The type of split function used to create node splits", impuritySet, impuritySet.OfType<M5Spliter>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafModel>(LeafModelParameterName, "The type of model used for the nodes", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruning>(PruningTypeParameterName, "The type of pruning used", pruningSet, pruningSet.OfType<M5LinearBottomUpPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UseHoldoutParameterName, "True if a holdout set should be generated, false if splitting and pruning shall be performed on the same data ", new BoolValue(false)));
      Problem = new RegressionProblem();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new M5Regression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Builds the model from the current parameter values and stores the solution plus analysis results in <c>Results</c>.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var random = new MersenneTwister();
      // The randomly drawn seed is written back into the parameter so the run can be reproduced.
      if (SetSeedRandomly) SeedParameter.Value.Value = new System.Random().Next();
      random.Reset(Seed);
      var solution = CreateM5RegressionSolution(Problem.ProblemData, random, LeafModel, Split, Pruning, UseHoldout, HoldoutSize, MinimalNodeSize, GenerateRules, Results, cancellationToken);
      AnalyzeSolution(solution);
    }

    #region Static Interface
    /// <summary>
    /// Builds an M5 tree or rule set on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data; only allowed inputs, the target and the training rows are used.</param>
    /// <param name="random">Random generator, used e.g. to draw the holdout set.</param>
    /// <param name="leafModel">Leaf model; defaults to <c>LinearLeaf</c>.</param>
    /// <param name="spliter">Split function; defaults to <c>M5Spliter</c>.</param>
    /// <param name="pruning">Pruning strategy; defaults to <c>M5LeafBottomUpPruning</c>.</param>
    /// <param name="useHoldout">If true, a fraction <paramref name="holdoutSize"/> of the training rows is reserved for pruning.</param>
    /// <param name="holdoutSize">Fraction of training rows used for pruning when <paramref name="useHoldout"/> is set.</param>
    /// <param name="minNumInstances">Minimal number of samples per node.</param>
    /// <param name="generateRules">If true a rule set is built, otherwise a tree.</param>
    /// <param name="results">Optional result collection handed to the model parameters.</param>
    /// <param name="cancellationToken">Optional cancellation token; defaults to <see cref="CancellationToken.None"/>.</param>
    /// <returns>A regression solution for the built model on <paramref name="problemData"/>.</returns>
    /// <exception cref="NotSupportedException">
    /// If an input or the target is not double valued, or the training data contains NaN/infinity.
    /// </exception>
    public static IRegressionSolution CreateM5RegressionSolution(IRegressionProblemData problemData, IRandom random,
      ILeafModel leafModel = null, ISpliter spliter = null, IPruning pruning = null,
      bool useHoldout = false, double holdoutSize = 0.2, int minNumInstances = 4, bool generateRules = false, ResultCollection results = null, CancellationToken? cancellationToken = null) {
      //set default values
      if (leafModel == null) leafModel = new LinearLeaf();
      if (spliter == null) spliter = new M5Spliter();
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      if (pruning == null) pruning = new M5LeafBottomUpPruning();

      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] {problemData.TargetVariable}).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("M5 regression supports only double valued input or output features.");

      var values = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (values.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("M5 regression does not support NaN or infinity values in the input dataset.");

      // Reduced copy of the data: only the relevant columns and only the training rows (re-indexed from 0),
      // with the whole copy declared as training partition.
      var trainingData = new Dataset(vars, values);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //create & build Model
      var m5Params = new M5Parameters(pruning, minNumInstances, leafModel, pd, random, spliter, results);

      // The model is built on pd, so the row indices must refer to pd's (re-indexed) dataset and not to
      // the original problemData; otherwise a training partition not starting at row 0 addresses wrong rows.
      IReadOnlyList<int> trainingRows, pruningRows;
      GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingRows, out pruningRows);

      IM5Model model;
      if (generateRules)
        model = M5RuleSetModel.CreateRuleModel(problemData.TargetVariable, m5Params);
      else
        model = M5TreeModel.CreateTreeModel(problemData.TargetVariable, m5Params);

      model.Build(trainingRows, pruningRows, m5Params, cancellationToken.Value);
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Re-fits an existing M5 model on the training rows of <paramref name="problemData"/> with the given leaf model.
    /// </summary>
    /// <exception cref="ArgumentException">If <paramref name="model"/> is not an <c>IM5Model</c>.</exception>
    public static void UpdateM5Model(IRegressionModel model, IRegressionProblemData problemData, IRandom random,
      ILeafModel leafModel, CancellationToken? cancellationToken = null) {
      var m5Model = model as IM5Model;
      if (m5Model == null) throw new ArgumentException("This type of model can not be updated");
      UpdateM5Model(m5Model, problemData, random, leafModel, cancellationToken);
    }

    private static void UpdateM5Model(IM5Model model, IRegressionProblemData problemData, IRandom random,
      ILeafModel leafModel = null, CancellationToken? cancellationToken = null) {
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      var m5Params = new M5Parameters(leafModel, problemData, random);
      model.Update(problemData.TrainingIndices.ToList(), m5Params, cancellationToken.Value);
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Splits <paramref name="allrows"/> into rows used for growing the model and rows used for pruning.
    /// Without holdout both sets are <paramref name="allrows"/>. With holdout a random permutation is drawn;
    /// its first <c>(int)(holdoutSize * allrows.Count)</c> entries select the pruning rows and the rest the training rows,
    /// so the two sets are disjoint.
    /// </summary>
    private static void GeneratePruningSet(IReadOnlyList<int> allrows, IRandom random, bool useHoldout, double holdoutSize, out IReadOnlyList<int> training, out IReadOnlyList<int> pruning) {
      if (!useHoldout) {
        training = allrows;
        pruning = allrows;
        return;
      }
      var perm = new Permutation(PermutationTypes.Absolute, allrows.Count, random);
      var cut = (int)(holdoutSize * allrows.Count);
      pruning = perm.Take(cut).Select(i => allrows[i]).ToArray();
      // Training rows are the complement of the pruning rows (previously Take(cut), which made both sets identical).
      training = perm.Skip(cut).Select(i => allrows[i]).ToArray();
    }

    /// <summary>
    /// Adds the solution, model-specific diagnostics (leaf depth histogram for trees; rules and coverage
    /// diagram for rule sets) and the relative variable frequencies to <c>Results</c>.
    /// </summary>
    private void AnalyzeSolution(IRegressionSolution solution) {
      Results.Add(new Result("RegressionSolution", (IItem)solution.Clone()));

      Dictionary<string, int> frequencies;
      if (!GenerateRules) {
        var treeModel = (M5TreeModel)solution.Model;
        Results.Add(M5Analyzer.CreateLeafDepthHistogram(treeModel));
        frequencies = M5Analyzer.GetTreeVariableFrequences(treeModel);
      }
      else {
        var ruleModel = (M5RuleSetModel)solution.Model;
        Results.Add(M5Analyzer.CreateRulesResult(ruleModel, Problem.ProblemData, "M5TreeResult", true));
        frequencies = M5Analyzer.GetRuleVariableFrequences(ruleModel);
        Results.Add(M5Analyzer.CreateCoverageDiagram(ruleModel, Problem.ProblemData));
      }

      //Variable frequencies, normalized to relative values (guarding against division by zero)
      var sum = frequencies.Values.Sum();
      sum = sum == 0 ? 1 : sum;
      var impactArray = new DoubleArray(frequencies.Select(i => (double)i.Value / sum).ToArray()) {
        ElementNames = frequencies.Select(i => i.Key)
      };
      Results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.