1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using System.Threading;
|
---|
5 | using HeuristicLab.Common;
|
---|
6 | using HeuristicLab.Core;
|
---|
7 | using HeuristicLab.Data;
|
---|
8 | using HeuristicLab.Optimization;
|
---|
9 | using HeuristicLab.Parameters;
|
---|
10 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
11 | using HeuristicLab.PluginInfrastructure;
|
---|
12 | using HeuristicLab.Problems.DataAnalysis;
|
---|
13 | using HeuristicLab.Random;
|
---|
14 |
|
---|
15 | namespace HeuristicLab.Algorithms.DataAnalysis {
|
---|
16 | [StorableClass]
|
---|
17 | [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
|
---|
18 | [Item("M5RegressionTree", "A M5 regression tree / rule set classifier")]
|
---|
19 | public sealed class M5Regression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
|
---|
20 | #region Parametername
|
---|
// Lookup keys of the algorithm parameters. The default constructor registers every parameter
// under one of these keys and the parameter properties retrieve it through the same key.
private const string GenerateRulesParameterName = "GenerateRules";
// NOTE: the constant is named "Impurity", but the key it holds is "Split".
private const string ImpurityParameterName = "Split";
private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
private const string ModelTypeParameterName = "ModelType";
private const string PruningTypeParameterName = "PruningType";
private const string SeedParameterName = "Seed";
private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
28 | #endregion
|
---|
29 |
|
---|
30 | #region Parameter properties
|
---|
/// <summary>Parameter stored under the "GenerateRules" key (rule set vs. tree).</summary>
public IFixedValueParameter<BoolValue> GenerateRulesParameter {
  get {
    var parameter = Parameters[GenerateRulesParameterName];
    return parameter as IFixedValueParameter<BoolValue>;
  }
}

/// <summary>Parameter stored under the "Split" key (split function used for node splits).</summary>
public IConstrainedValueParameter<ISplitType> ImpurityParameter {
  get {
    var parameter = Parameters[ImpurityParameterName];
    return parameter as IConstrainedValueParameter<ISplitType>;
  }
}

/// <summary>Parameter stored under the "MinimalNodeSize" key.</summary>
/// <remarks>
/// Unlike the sibling accessors this one uses a hard cast, so a parameter of an unexpected
/// type raises an <see cref="InvalidCastException"/> instead of yielding null.
/// </remarks>
public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
  get {
    var parameter = Parameters[MinimalNodeSizeParameterName];
    return (IFixedValueParameter<IntValue>) parameter;
  }
}

/// <summary>Parameter stored under the "ModelType" key (model type used for the nodes).</summary>
public IConstrainedValueParameter<ILeafType<IRegressionModel>> ModelTypeParameter {
  get {
    var parameter = Parameters[ModelTypeParameterName];
    return parameter as IConstrainedValueParameter<ILeafType<IRegressionModel>>;
  }
}

/// <summary>Parameter stored under the "PruningType" key.</summary>
public IConstrainedValueParameter<IPruningType> PruningTypeParameter {
  get {
    var parameter = Parameters[PruningTypeParameterName];
    return parameter as IConstrainedValueParameter<IPruningType>;
  }
}

/// <summary>Parameter stored under the "Seed" key.</summary>
public IFixedValueParameter<IntValue> SeedParameter {
  get {
    var parameter = Parameters[SeedParameterName];
    return parameter as IFixedValueParameter<IntValue>;
  }
}

/// <summary>Parameter stored under the "SetSeedRandomly" key.</summary>
public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
  get {
    var parameter = Parameters[SetSeedRandomlyParameterName];
    return parameter as IFixedValueParameter<BoolValue>;
  }
}
52 | #endregion
|
---|
53 |
|
---|
54 | #region Properties
|
---|
/// <summary>Current value of <see cref="GenerateRulesParameter"/>.</summary>
public bool GenerateRules {
  get {
    var flag = GenerateRulesParameter.Value;
    return flag.Value;
  }
}

/// <summary>Currently selected value of <see cref="ImpurityParameter"/>.</summary>
public ISplitType Split {
  get {
    return ImpurityParameter.Value;
  }
}

/// <summary>Current value of <see cref="MinimalNodeSizeParameter"/>.</summary>
public int MinimalNodeSize {
  get {
    var size = MinimalNodeSizeParameter.Value;
    return size.Value;
  }
}

/// <summary>Currently selected value of <see cref="ModelTypeParameter"/>.</summary>
public ILeafType<IRegressionModel> LeafType {
  get {
    return ModelTypeParameter.Value;
  }
}

/// <summary>Currently selected value of <see cref="PruningTypeParameter"/>.</summary>
public IPruningType PruningType {
  get {
    return PruningTypeParameter.Value;
  }
}

/// <summary>Current value of <see cref="SeedParameter"/>.</summary>
public int Seed {
  get {
    var seed = SeedParameter.Value;
    return seed.Value;
  }
}

/// <summary>Current value of <see cref="SetSeedRandomlyParameter"/>.</summary>
public bool SetSeedRandomly {
  get {
    var flag = SetSeedRandomlyParameter.Value;
    return flag.Value;
  }
}
76 | #endregion
|
---|
77 |
|
---|
78 | #region Constructors and Cloning
|
---|
/// <summary>
/// Constructor marked with <c>[StorableConstructor]</c>; it only forwards the
/// <paramref name="deserializing"/> flag to the base class.
/// </summary>
[StorableConstructor]
private M5Regression(bool deserializing)
  : base(deserializing) {
}
/// <summary>
/// Copy constructor invoked by <see cref="Clone(Cloner)"/>; all copying is delegated to the base class.
/// </summary>
private M5Regression(M5Regression original, Cloner cloner)
  : base(original, cloner) {
}
/// <summary>
/// Creates the algorithm, registers all of its parameters with their default values
/// and assigns a new <see cref="RegressionProblem"/>.
/// </summary>
public M5Regression() {
  // The valid choices of the constrained parameters are all instances reported by the application manager.
  var leafTypes = new ItemSet<ILeafType<IRegressionModel>>(ApplicationManager.Manager.GetInstances<ILeafType<IRegressionModel>>());
  var pruningTypes = new ItemSet<IPruningType>(ApplicationManager.Manager.GetInstances<IPruningType>());
  var splitTypes = new ItemSet<ISplitType>(ApplicationManager.Manager.GetInstances<ISplitType>());

  Parameters.Add(new FixedValueParameter<BoolValue>(
    GenerateRulesParameterName,
    "Whether a set of rules or a decision tree shall be created",
    new BoolValue(true)));

  // First() throws if no instance of the default type is among the discovered ones.
  var defaultSplit = splitTypes.OfType<OrderSplitType>().First();
  Parameters.Add(new ConstrainedValueParameter<ISplitType>(
    ImpurityParameterName,
    "The type of split function used to create node splits",
    splitTypes,
    defaultSplit));

  Parameters.Add(new FixedValueParameter<IntValue>(
    MinimalNodeSizeParameterName,
    "The minimal number of samples in a leaf node",
    new IntValue(1)));

  var defaultLeaf = leafTypes.OfType<LinearLeaf>().First();
  Parameters.Add(new ConstrainedValueParameter<ILeafType<IRegressionModel>>(
    ModelTypeParameterName,
    "The type of model used for the nodes",
    leafTypes,
    defaultLeaf));

  var defaultPruning = pruningTypes.OfType<M5LeafPruning>().First();
  Parameters.Add(new ConstrainedValueParameter<IPruningType>(
    PruningTypeParameterName,
    "The type of pruning used",
    pruningTypes,
    defaultPruning));

  Parameters.Add(new FixedValueParameter<IntValue>(
    SeedParameterName,
    "The random seed used to initialize the new pseudo random number generator.",
    new IntValue(0)));

  Parameters.Add(new FixedValueParameter<BoolValue>(
    SetSeedRandomlyParameterName,
    "True if the random seed should be set to a random value, otherwise false.",
    new BoolValue(true)));

  Problem = new RegressionProblem();
}
/// <summary>Creates a copy of this algorithm through the private cloning constructor.</summary>
/// <param name="cloner">Cloner handed on to the cloning constructor.</param>
/// <returns>The newly created <see cref="M5Regression"/> instance.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new M5Regression(this, cloner);
  return copy;
}
98 | #endregion
|
---|
99 |
|
---|
/// <summary>
/// Executes the algorithm: prepares the random number generator, builds the M5 solution
/// for the current problem data and publishes the analysis results.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the model building.</param>
protected override void Run(CancellationToken cancellationToken) {
  var twister = new MersenneTwister();
  if (SetSeedRandomly) {
    // The freshly drawn seed is written into the Seed parameter, so it is the value read below.
    SeedParameter.Value.Value = new System.Random().Next();
  }
  twister.Reset(Seed);

  AnalyzeSolution(CreateM5RegressionSolution(Problem.ProblemData, twister, LeafType, Split, PruningType,
    cancellationToken, MinimalNodeSize, GenerateRules, Results));
}
107 |
|
---|
108 | #region Static Interface
|
---|
/// <summary>
/// Builds an M5 model (a tree or a rule set) from the training rows of <paramref name="problemData"/>
/// and returns the regression solution that the model creates for <paramref name="problemData"/>.
/// </summary>
/// <param name="problemData">Problem data; only its allowed input variables, target variable and training rows are used for building.</param>
/// <param name="random">Random number generator handed to the hold-out set generation and to the creation parameters.</param>
/// <param name="leafType">Leaf model type; a new <c>LinearLeaf</c> is used when null.</param>
/// <param name="splitType">Split type; a new <c>OrderSplitType</c> is used when null.</param>
/// <param name="pruningType">Pruning type; a new <c>M5LeafPruning</c> is used when null.</param>
/// <param name="cancellationToken">Token forwarded to the model building; <see cref="CancellationToken.None"/> when null.</param>
/// <param name="minNumInstances">Value forwarded to the creation parameters (the algorithm passes its MinimalNodeSize here).</param>
/// <param name="generateRules">True to build an <c>M5RuleSetModel</c>, false to build an <c>M5TreeModel</c>.</param>
/// <param name="results">Result collection forwarded to the creation parameters; may be null.</param>
/// <returns>The regression solution created by the built model for <paramref name="problemData"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
/// <exception cref="NotSupportedException">
/// An allowed input variable or the target variable is not double valued, or the training rows
/// contain NaN or infinite values.
/// </exception>
public static IRegressionSolution CreateM5RegressionSolution(IRegressionProblemData problemData, IRandom random,
  ILeafType<IRegressionModel> leafType = null, ISplitType splitType = null, IPruningType pruningType = null,
  CancellationToken? cancellationToken = null, int minNumInstances = 4, bool generateRules = false, ResultCollection results = null) {
  if (problemData == null) throw new ArgumentNullException("problemData");

  //set default values
  if (leafType == null) leafType = new LinearLeaf();
  if (splitType == null) splitType = new OrderSplitType();
  if (cancellationToken == null) cancellationToken = CancellationToken.None;
  if (pruningType == null) pruningType = new M5LeafPruning();

  //reduce the data column wise to allowed inputs + target and row wise to the training rows
  var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
  var vars = problemData.AllowedInputVariables.Concat(new[] {problemData.TargetVariable}).ToArray();
  if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("M5 regression does not support non-double valued input or output features.");

  var values = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
  if (values.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
    throw new NotSupportedException("M5 regression does not support NaN or infinity values in the input dataset.");
  var trainingData = new Dataset(vars, values);
  var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
  //every row of the reduced data set is a training row, the test partition is empty
  pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
  pd.TrainingPartition.Start = 0;

  //create & build Model
  var m5Params = new M5CreationParameters(pruningType, minNumInstances, leafType, pd, random, splitType, results);

  // pd is the only data handed to the model building (via m5Params) and its rows are renumbered
  // starting at 0. The hold-out split therefore has to be drawn from pd's row indices; the indices
  // of the original data set would address wrong or non-existing rows whenever the original
  // training partition does not start at row 0.
  IReadOnlyList<int> trainingRows, holdOutRows;
  pruningType.GenerateHoldOutSet(pd.TrainingIndices.ToArray(), random, out trainingRows, out holdOutRows);

  IM5MetaModel model;
  if (generateRules) {
    model = M5RuleSetModel.CreateRuleModel(problemData.TargetVariable, m5Params);
  } else {
    model = M5TreeModel.CreateTreeModel(problemData.TargetVariable, m5Params);
  }
  model.BuildClassifier(trainingRows, holdOutRows, m5Params, cancellationToken.Value);
  return model.CreateRegressionSolution(problemData);
}
148 |
|
---|
/// <summary>
/// Updates an existing M5 tree model; the work is delegated to the private overload
/// that operates on <c>IM5MetaModel</c>.
/// </summary>
/// <param name="model">Tree model to update.</param>
/// <param name="problemData">Problem data passed on to the update.</param>
/// <param name="random">Random number generator passed on to the update.</param>
/// <param name="leafType">Leaf type passed on unchanged (may be null).</param>
/// <param name="cancellationToken">Optional token passed on unchanged.</param>
public static void UpdateM5Model(M5TreeModel model, IRegressionProblemData problemData, IRandom random,
  ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
  IM5MetaModel metaModel = model as IM5MetaModel;
  UpdateM5Model(metaModel, problemData, random, leafType, cancellationToken);
}
153 |
|
---|
/// <summary>
/// Updates an existing M5 rule set model; the work is delegated to the private overload
/// that operates on <c>IM5MetaModel</c>.
/// </summary>
/// <param name="model">Rule set model to update.</param>
/// <param name="problemData">Problem data passed on to the update.</param>
/// <param name="random">Random number generator passed on to the update.</param>
/// <param name="leafType">Leaf type passed on unchanged (may be null).</param>
/// <param name="cancellationToken">Optional token passed on unchanged.</param>
public static void UpdateM5Model(M5RuleSetModel model, IRegressionProblemData problemData, IRandom random,
  ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
  IM5MetaModel metaModel = model as IM5MetaModel;
  UpdateM5Model(metaModel, problemData, random, leafType, cancellationToken);
}
158 |
|
---|
/// <summary>
/// Shared implementation of the public UpdateM5Model overloads: creates the update parameters
/// and calls <c>UpdateModel</c> on the model with the training rows of <paramref name="problemData"/>.
/// </summary>
/// <param name="model">Model to update.</param>
/// <param name="problemData">Problem data whose training indices are handed to the model.</param>
/// <param name="random">Random number generator stored in the update parameters.</param>
/// <param name="leafType">
/// Leaf type stored in the update parameters. It is passed on as given; unlike in
/// CreateM5RegressionSolution no default is substituted for null.
/// </param>
/// <param name="cancellationToken">Token for the update; <see cref="CancellationToken.None"/> when null.</param>
private static void UpdateM5Model(IM5MetaModel model, IRegressionProblemData problemData, IRandom random,
  ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
  var token = cancellationToken ?? CancellationToken.None;
  var updateParameters = new M5UpdateParameters(leafType, problemData, random);
  var trainingRows = problemData.TrainingIndices.ToList();
  model.UpdateModel(trainingRows, updateParameters, token);
}
165 | #endregion
|
---|
166 |
|
---|
167 | #region Helpers
|
---|
/// <summary>
/// Publishes the analysis of a finished run in the Results collection: a clone of the solution,
/// the model specific results (depending on <see cref="GenerateRules"/>) and the relative
/// variable frequencies.
/// </summary>
/// <param name="solution">
/// Solution whose model is cast to <c>M5RuleSetModel</c> when rules were generated and to
/// <c>M5TreeModel</c> otherwise.
/// </param>
private void AnalyzeSolution(IRegressionSolution solution) {
  Results.Add(new Result("RegressionSolution", (IItem) solution.Clone()));

  Dictionary<string, int> variableCounts;
  if (GenerateRules) {
    var ruleSet = (M5RuleSetModel) solution.Model;
    Results.Add(M5Analyzer.CreateRulesResult(ruleSet, Problem.ProblemData, "M5TreeResult", true));
    variableCounts = M5Analyzer.GetRuleVariableFrequences(ruleSet);
    Results.Add(M5Analyzer.CreateCoverageDiagram(ruleSet, Problem.ProblemData));
  } else {
    var tree = (M5TreeModel) solution.Model;
    Results.Add(M5Analyzer.CreateLeafDepthHistogram(tree));
    variableCounts = M5Analyzer.GetTreeVariableFrequences(tree);
  }

  //Variable frequencies
  var total = variableCounts.Values.Sum();
  if (total == 0) total = 1; // avoids a division by zero when no variable was counted
  var relativeFrequencies = variableCounts.Select(entry => (double) entry.Value / total).ToArray();
  var impactArray = new DoubleArray(relativeFrequencies);
  impactArray.ElementNames = variableCounts.Select(entry => entry.Key);
  Results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
}
190 | #endregion
|
---|
191 | }
|
---|
192 | } |
---|