Free cookie consent management tool by TermsFeed Policy Generator

source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/M5Regression.cs @ 15470

Last change on this file since 15470 was 15470, checked in by bwerth, 5 years ago

#2847 worked on M5Regression

File size: 10.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Optimization;
9using HeuristicLab.Parameters;
10using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
11using HeuristicLab.PluginInfrastructure;
12using HeuristicLab.Problems.DataAnalysis;
13using HeuristicLab.Random;
14
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Algorithm item that builds either an M5 regression tree or, when <see cref="GenerateRules"/> is set,
  /// an M5 rule set for an <see cref="IRegressionProblem"/>.
  /// </summary>
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("M5RegressionTree", "A M5 regression tree / rule set classifier")]
  public sealed class M5Regression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region Parametername
    private const string GenerateRulesParameterName = "GenerateRules";
    private const string ImpurityParameterName = "Split";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string ModelTypeParameterName = "ModelType";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    #endregion

    #region Parameter properties
    // All lookups use a direct cast so that a parameter of an unexpected type fails right here
    // with an InvalidCastException instead of surfacing later as a NullReferenceException.

    /// <summary>Parameter deciding whether a set of rules or a decision tree shall be created.</summary>
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return (IFixedValueParameter<BoolValue>) Parameters[GenerateRulesParameterName]; }
    }
    /// <summary>Parameter holding the type of split function used to create node splits.</summary>
    public IConstrainedValueParameter<ISplitType> ImpurityParameter {
      get { return (IConstrainedValueParameter<ISplitType>) Parameters[ImpurityParameterName]; }
    }
    /// <summary>Parameter holding the minimal number of samples in a leaf node.</summary>
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>) Parameters[MinimalNodeSizeParameterName]; }
    }
    /// <summary>Parameter holding the type of model used for the nodes.</summary>
    public IConstrainedValueParameter<ILeafType<IRegressionModel>> ModelTypeParameter {
      get { return (IConstrainedValueParameter<ILeafType<IRegressionModel>>) Parameters[ModelTypeParameterName]; }
    }
    /// <summary>Parameter holding the type of pruning used.</summary>
    public IConstrainedValueParameter<IPruningType> PruningTypeParameter {
      get { return (IConstrainedValueParameter<IPruningType>) Parameters[PruningTypeParameterName]; }
    }
    /// <summary>Parameter holding the seed used to initialize the pseudo random number generator.</summary>
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>) Parameters[SeedParameterName]; }
    }
    /// <summary>Parameter deciding whether the seed is replaced by a random value on each run.</summary>
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>) Parameters[SetSeedRandomlyParameterName]; }
    }
    #endregion

    #region Properties
    /// <summary>Current value of <see cref="GenerateRulesParameter"/>.</summary>
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
    }
    /// <summary>Currently selected split type (value of <see cref="ImpurityParameter"/>).</summary>
    public ISplitType Split {
      get { return ImpurityParameter.Value; }
    }
    /// <summary>Current value of <see cref="MinimalNodeSizeParameter"/>.</summary>
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
    }
    /// <summary>Currently selected leaf model type (value of <see cref="ModelTypeParameter"/>).</summary>
    public ILeafType<IRegressionModel> LeafType {
      get { return ModelTypeParameter.Value; }
    }
    /// <summary>Currently selected pruning type (value of <see cref="PruningTypeParameter"/>).</summary>
    public IPruningType PruningType {
      get { return PruningTypeParameter.Value; }
    }
    /// <summary>Current value of <see cref="SeedParameter"/>.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
    }
    /// <summary>Current value of <see cref="SetSeedRandomlyParameter"/>.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
    }
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private M5Regression(bool deserializing) : base(deserializing) { }
    private M5Regression(M5Regression original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the algorithm with its parameters and a fresh <see cref="RegressionProblem"/>.
    /// The selectable leaf, pruning and split types are the instances reported by the application manager;
    /// the initially selected ones are <see cref="LinearLeaf"/>, <see cref="M5LeafPruning"/> and
    /// <see cref="OrderSplitType"/>.
    /// </summary>
    public M5Regression() {
      var modelSet = new ItemSet<ILeafType<IRegressionModel>>(ApplicationManager.Manager.GetInstances<ILeafType<IRegressionModel>>());
      var pruningSet = new ItemSet<IPruningType>(ApplicationManager.Manager.GetInstances<IPruningType>());
      var impuritySet = new ItemSet<ISplitType>(ApplicationManager.Manager.GetInstances<ISplitType>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created", new BoolValue(true)));
      Parameters.Add(new ConstrainedValueParameter<ISplitType>(ImpurityParameterName, "The type of split function used to create node splits", impuritySet, impuritySet.OfType<OrderSplitType>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafType<IRegressionModel>>(ModelTypeParameterName, "The type of model used for the nodes", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruningType>(PruningTypeParameterName, "The type of pruning used", pruningSet, pruningSet.OfType<M5LeafPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Problem = new RegressionProblem();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new M5Regression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Seeds a <see cref="MersenneTwister"/>, builds the M5 solution for the current problem data
    /// with the configured parameter values and adds the analysis results to <c>Results</c>.
    /// </summary>
    /// <param name="cancellationToken">Token forwarded to the model building step.</param>
    protected override void Run(CancellationToken cancellationToken) {
      var random = new MersenneTwister();
      // The drawn seed is written back into the parameter so that it is the value used by Reset below.
      if (SetSeedRandomly) SeedParameter.Value.Value = new System.Random().Next();
      random.Reset(Seed);
      var solution = CreateM5RegressionSolution(Problem.ProblemData, random, LeafType, Split, PruningType, cancellationToken, MinimalNodeSize, GenerateRules, Results);
      AnalyzeSolution(solution);
    }

    #region Static Interface
    /// <summary>
    /// Builds an M5 tree model or an M5 rule set model on the training rows of
    /// <paramref name="problemData"/> and wraps it in a regression solution.
    /// </summary>
    /// <param name="problemData">Problem data; only its allowed input variables and target variable are used.</param>
    /// <param name="random">Random number generator handed to the creation parameters and to the hold-out split.</param>
    /// <param name="leafType">Leaf model type; a new <see cref="LinearLeaf"/> is used when null.</param>
    /// <param name="splitType">Split type; a new <see cref="OrderSplitType"/> is used when null.</param>
    /// <param name="pruningType">Pruning type; a new <see cref="M5LeafPruning"/> is used when null.</param>
    /// <param name="cancellationToken">Optional token; <see cref="CancellationToken.None"/> is used when null.</param>
    /// <param name="minNumInstances">Value forwarded to the creation parameters as the minimal number of instances.</param>
    /// <param name="generateRules">True to create a rule set model, false to create a tree model.</param>
    /// <param name="results">Result collection forwarded to the creation parameters as-is (null by default).</param>
    /// <returns>The regression solution created by the built model for <paramref name="problemData"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
    /// <exception cref="NotSupportedException">
    /// A used variable is not double valued, or the training rows contain NaN or infinity values.
    /// </exception>
    public static IRegressionSolution CreateM5RegressionSolution(IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, ISplitType splitType = null, IPruningType pruningType = null,
      CancellationToken? cancellationToken = null, int minNumInstances = 4, bool generateRules = false, ResultCollection results = null) {
      if (problemData == null) throw new ArgumentNullException("problemData");

      //set default values
      if (leafType == null) leafType = new LinearLeaf();
      if (splitType == null) splitType = new OrderSplitType();
      if (pruningType == null) pruningType = new M5LeafPruning();
      var token = cancellationToken ?? CancellationToken.None;

      //only double valued variables are supported
      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] {problemData.TargetVariable}).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("M5 regression does not support non-double valued input or output features.");

      //reduce the data to the used columns and to the training rows
      var values = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (values.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("M5 regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, values);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      //every row of the reduced dataset is a training row, the test partition is empty
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //create & build Model
      var m5Params = new M5CreationParameters(pruningType, minNumInstances, leafType, pd, random, splitType, results);

      // The creation parameters carry the reduced data "pd", whose rows are numbered 0..Rows-1.
      // The training/hold-out rows therefore have to be drawn from pd's indices; the indices of the
      // original problem data only coincide with them when its training partition starts at row 0.
      IReadOnlyList<int> t, h;
      pruningType.GenerateHoldOutSet(pd.TrainingIndices.ToArray(), random, out t, out h);

      IM5MetaModel model;
      if (generateRules) model = M5RuleSetModel.CreateRuleModel(problemData.TargetVariable, m5Params);
      else model = M5TreeModel.CreateTreeModel(problemData.TargetVariable, m5Params);

      model.BuildClassifier(t, h, m5Params, token);
      //the solution is created on the original (unreduced) problem data
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Updates an existing M5 tree model using the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="model">The tree model to update.</param>
    /// <param name="problemData">Problem data whose training indices are passed to the model update.</param>
    /// <param name="random">Random number generator handed to the update parameters.</param>
    /// <param name="leafType">Leaf model type forwarded to the update parameters as-is (no default is substituted).</param>
    /// <param name="cancellationToken">Optional token; <see cref="CancellationToken.None"/> is used when null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> or <paramref name="problemData"/> is null.</exception>
    public static void UpdateM5Model(M5TreeModel model, IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
      UpdateM5Model(model as IM5MetaModel, problemData, random, leafType, cancellationToken);
    }

    /// <summary>
    /// Updates an existing M5 rule set model using the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="model">The rule set model to update.</param>
    /// <param name="problemData">Problem data whose training indices are passed to the model update.</param>
    /// <param name="random">Random number generator handed to the update parameters.</param>
    /// <param name="leafType">Leaf model type forwarded to the update parameters as-is (no default is substituted).</param>
    /// <param name="cancellationToken">Optional token; <see cref="CancellationToken.None"/> is used when null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> or <paramref name="problemData"/> is null.</exception>
    public static void UpdateM5Model(M5RuleSetModel model, IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
      UpdateM5Model(model as IM5MetaModel, problemData, random, leafType, cancellationToken);
    }

    // Shared implementation of the two public UpdateM5Model overloads.
    private static void UpdateM5Model(IM5MetaModel model, IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
      if (model == null) throw new ArgumentNullException("model");
      if (problemData == null) throw new ArgumentNullException("problemData");
      var token = cancellationToken ?? CancellationToken.None;
      var m5Params = new M5UpdateParameters(leafType, problemData, random);
      model.UpdateModel(problemData.TrainingIndices.ToList(), m5Params, token);
    }
    #endregion

    #region Helpers
    /// <summary>
    /// Adds a clone of the solution, the model specific analysis results (leaf depth histogram for trees;
    /// rules result and coverage diagram for rule sets) and the relative variable frequencies to <c>Results</c>.
    /// </summary>
    /// <param name="solution">
    /// The solution produced by this run; its model is cast to <see cref="M5TreeModel"/> or
    /// <see cref="M5RuleSetModel"/> depending on <see cref="GenerateRules"/>.
    /// </param>
    private void AnalyzeSolution(IRegressionSolution solution) {
      Results.Add(new Result("RegressionSolution", (IItem) solution.Clone()));

      Dictionary<string, int> frequencies;
      if (GenerateRules) {
        var ruleSetModel = (M5RuleSetModel) solution.Model;
        Results.Add(M5Analyzer.CreateRulesResult(ruleSetModel, Problem.ProblemData, "M5TreeResult", true));
        frequencies = M5Analyzer.GetRuleVariableFrequences(ruleSetModel);
        Results.Add(M5Analyzer.CreateCoverageDiagram(ruleSetModel, Problem.ProblemData));
      }
      else {
        var treeModel = (M5TreeModel) solution.Model;
        Results.Add(M5Analyzer.CreateLeafDepthHistogram(treeModel));
        frequencies = M5Analyzer.GetTreeVariableFrequences(treeModel);
      }

      //Variable frequencies
      var sum = frequencies.Values.Sum();
      //a total of 0 is replaced by 1 so that the division below is well defined (all shares are then 0)
      sum = sum == 0 ? 1 : sum;
      var impactArray = new DoubleArray(frequencies.Select(i => (double) i.Value / sum).ToArray()) {
        ElementNames = frequencies.Select(i => i.Key)
      };
      Results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.