[15430] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using System.Threading;
|
---|
| 5 | using HeuristicLab.Common;
|
---|
| 6 | using HeuristicLab.Core;
|
---|
| 7 | using HeuristicLab.Data;
|
---|
| 8 | using HeuristicLab.Optimization;
|
---|
| 9 | using HeuristicLab.Parameters;
|
---|
| 10 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
| 11 | using HeuristicLab.PluginInfrastructure;
|
---|
| 12 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 13 | using HeuristicLab.Random;
|
---|
| 14 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Data-analysis algorithm that learns either an M5 regression tree or an M5 rule set
  /// (selected by the "GenerateRules" parameter) for an <see cref="IRegressionProblem"/>.
  /// </summary>
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 95)]
  [Item("M5RegressionTree", "A M5 regression tree / rule set classifier")]
  public sealed class M5Regression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region Parametername
    private const string GenerateRulesParameterName = "GenerateRules";
    // The split-type parameter is registered under the name "Split"; the constant and the
    // ImpurityParameter property keep their older "Impurity" naming for compatibility.
    private const string ImpurityParameterName = "Split";
    private const string MinimalNodeSizeParameterName = "MinimalNodeSize";
    private const string ModelTypeParameterName = "ModelType";
    private const string PruningTypeParameterName = "PruningType";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    #endregion

    #region Parameter properties
    /// <summary>Whether a set of rules (true) or a decision tree (false) is created.</summary>
    public IFixedValueParameter<BoolValue> GenerateRulesParameter {
      get { return Parameters[GenerateRulesParameterName] as IFixedValueParameter<BoolValue>; }
    }
    /// <summary>The type of split function used to create node splits (parameter name "Split").</summary>
    public IConstrainedValueParameter<ISplitType> ImpurityParameter {
      get { return Parameters[ImpurityParameterName] as IConstrainedValueParameter<ISplitType>; }
    }
    /// <summary>The minimal number of samples in a leaf node.</summary>
    public IFixedValueParameter<IntValue> MinimalNodeSizeParameter {
      get { return (IFixedValueParameter<IntValue>) Parameters[MinimalNodeSizeParameterName]; }
    }
    /// <summary>The type of model used for the nodes.</summary>
    public IConstrainedValueParameter<ILeafType<IRegressionModel>> ModelTypeParameter {
      get { return Parameters[ModelTypeParameterName] as IConstrainedValueParameter<ILeafType<IRegressionModel>>; }
    }
    /// <summary>The type of pruning used.</summary>
    public IConstrainedValueParameter<IPruningType> PruningTypeParameter {
      get { return Parameters[PruningTypeParameterName] as IConstrainedValueParameter<IPruningType>; }
    }
    /// <summary>The seed used to reset the pseudo random number generator in <see cref="Run"/>.</summary>
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }
    }
    /// <summary>If true, <see cref="Run"/> overwrites the seed with a freshly drawn random value.</summary>
    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return Parameters[SetSeedRandomlyParameterName] as IFixedValueParameter<BoolValue>; }
    }
    #endregion

    #region Properties
    /// <summary>Current value of <see cref="GenerateRulesParameter"/>.</summary>
    public bool GenerateRules {
      get { return GenerateRulesParameter.Value.Value; }
    }
    /// <summary>Currently selected split type (<see cref="ImpurityParameter"/>).</summary>
    public ISplitType Split {
      get { return ImpurityParameter.Value; }
    }
    /// <summary>Current value of <see cref="MinimalNodeSizeParameter"/>.</summary>
    public int MinimalNodeSize {
      get { return MinimalNodeSizeParameter.Value.Value; }
    }
    /// <summary>Currently selected leaf model type (<see cref="ModelTypeParameter"/>).</summary>
    public ILeafType<IRegressionModel> LeafType {
      get { return ModelTypeParameter.Value; }
    }
    /// <summary>Currently selected pruning type (<see cref="PruningTypeParameter"/>).</summary>
    public IPruningType PruningType {
      get { return PruningTypeParameter.Value; }
    }
    /// <summary>Current value of <see cref="SeedParameter"/>.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
    }
    /// <summary>Current value of <see cref="SetSeedRandomlyParameter"/>.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
    }
    #endregion

    #region Constructors and Cloning
    [StorableConstructor]
    private M5Regression(bool deserializing) : base(deserializing) { }
    private M5Regression(M5Regression original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the algorithm with its parameters and a fresh <see cref="RegressionProblem"/>.
    /// Defaults: GenerateRules = true, Split = OrderSplitType, MinimalNodeSize = 1,
    /// ModelType = LinearLeaf, PruningType = M5LeafPruning, Seed = 0, SetSeedRandomly = true.
    /// The selectable leaf, pruning and split types are all instances discovered by the plugin manager.
    /// </summary>
    public M5Regression() {
      var modelSet = new ItemSet<ILeafType<IRegressionModel>>(ApplicationManager.Manager.GetInstances<ILeafType<IRegressionModel>>());
      var pruningSet = new ItemSet<IPruningType>(ApplicationManager.Manager.GetInstances<IPruningType>());
      var impuritySet = new ItemSet<ISplitType>(ApplicationManager.Manager.GetInstances<ISplitType>());
      Parameters.Add(new FixedValueParameter<BoolValue>(GenerateRulesParameterName, "Whether a set of rules or a decision tree shall be created", new BoolValue(true)));
      Parameters.Add(new ConstrainedValueParameter<ISplitType>(ImpurityParameterName, "The type of split function used to create node splits", impuritySet, impuritySet.OfType<OrderSplitType>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(MinimalNodeSizeParameterName, "The minimal number of samples in a leaf node", new IntValue(1)));
      Parameters.Add(new ConstrainedValueParameter<ILeafType<IRegressionModel>>(ModelTypeParameterName, "The type of model used for the nodes", modelSet, modelSet.OfType<LinearLeaf>().First()));
      Parameters.Add(new ConstrainedValueParameter<IPruningType>(PruningTypeParameterName, "The type of pruning used", pruningSet, pruningSet.OfType<M5LeafPruning>().First()));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Problem = new RegressionProblem();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new M5Regression(this, cloner);
    }
    #endregion

    /// <summary>
    /// Optionally draws a new seed, resets a <see cref="MersenneTwister"/> with it, builds the
    /// M5 solution from the current parameter values and writes the analysis results.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var random = new MersenneTwister();
      if (SetSeedRandomly) SeedParameter.Value.Value = new System.Random().Next();
      random.Reset(Seed);
      var solution = CreateM5RegressionSolution(Problem.ProblemData, random, LeafType, Split, PruningType, cancellationToken, MinimalNodeSize, GenerateRules, Results);
      AnalyzeSolution(solution);
    }

    #region Static Interface
    /// <summary>
    /// Builds an M5 tree (or rule set when <paramref name="generateRules"/> is true) on the
    /// training rows of <paramref name="problemData"/> and wraps it in a regression solution.
    /// </summary>
    /// <param name="problemData">Problem data; only its training rows are used for learning.</param>
    /// <param name="random">Random number generator passed to the hold-out split and the model builder.</param>
    /// <param name="leafType">Leaf model type; defaults to <c>LinearLeaf</c> when null.</param>
    /// <param name="splitType">Split type; defaults to <c>OrderSplitType</c> when null.</param>
    /// <param name="pruningType">Pruning type; defaults to <c>M5LeafPruning</c> when null.</param>
    /// <param name="cancellationToken">Cancellation token; <see cref="CancellationToken.None"/> when null.</param>
    /// <param name="minNumInstances">Minimal number of instances, forwarded to the creation parameters.</param>
    /// <param name="generateRules">True to build an <c>M5RuleSetModel</c>, false for an <c>M5TreeModel</c>.</param>
    /// <param name="results">Optional result collection forwarded to the creation parameters.</param>
    /// <returns>The regression solution created by the built model for <paramref name="problemData"/>.</returns>
    /// <exception cref="NotSupportedException">
    /// An input or target variable is not double-valued, or the training rows contain NaN or infinity.
    /// </exception>
    public static IRegressionSolution CreateM5RegressionSolution(IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, ISplitType splitType = null, IPruningType pruningType = null,
      CancellationToken? cancellationToken = null, int minNumInstances = 4, bool generateRules = false, ResultCollection results = null) {
      //set default values
      if (leafType == null) leafType = new LinearLeaf();
      if (splitType == null) splitType = new OrderSplitType();
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      if (pruningType == null) pruningType = new M5LeafPruning();

      var doubleVars = new HashSet<string>(problemData.Dataset.DoubleVariables);
      var vars = problemData.AllowedInputVariables.Concat(new[] {problemData.TargetVariable}).ToArray();
      if (vars.Any(v => !doubleVars.Contains(v))) throw new NotSupportedException("M5 regression does not support non-double valued input or output features.");

      // Copy only the training rows into a new, compact dataset (row i corresponds to the i-th training index).
      var values = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
      if (values.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
        throw new NotSupportedException("M5 regression does not support NaN or infinity values in the input dataset.");
      var trainingData = new Dataset(vars, values);
      var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
      // Make every row of the compact dataset a training row and leave the test partition empty.
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
      pd.TrainingPartition.Start = 0;

      //create & build Model
      var m5Params = new M5CreationParameters(pruningType, minNumInstances, leafType, pd, random, splitType, results);

      // The row indices handed to the model must address the compact dataset "pd" that m5Params
      // is built on (rows 0..Rows-1). problemData.TrainingIndices address the ORIGINAL dataset and
      // may be offset or non-contiguous, so they must not be used here.
      IReadOnlyList<int> t, h;
      pruningType.GenerateHoldOutSet(pd.TrainingIndices.ToArray(), random, out t, out h);

      IM5MetaModel model;
      if (generateRules)
        model = M5RuleSetModel.CreateRuleModel(problemData.TargetVariable, m5Params);
      else
        model = M5TreeModel.CreateTreeModel(problemData.TargetVariable, m5Params);
      model.BuildClassifier(t, h, m5Params, cancellationToken.Value);
      return model.CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Updates an existing M5 tree model using the training rows of <paramref name="problemData"/>.
    /// </summary>
    public static void UpdateM5Model(M5TreeModel model, IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
      UpdateM5Model(model as IM5MetaModel, problemData, random, leafType, cancellationToken);
    }

    /// <summary>
    /// Updates an existing M5 rule set model using the training rows of <paramref name="problemData"/>.
    /// </summary>
    public static void UpdateM5Model(M5RuleSetModel model, IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
      UpdateM5Model(model as IM5MetaModel, problemData, random, leafType, cancellationToken);
    }

    // Shared implementation of both public UpdateM5Model overloads: calls model.UpdateModel with the
    // training row indices of problemData and an M5UpdateParameters built from leafType/problemData/random.
    // NOTE(review): unlike CreateM5RegressionSolution, no default is substituted for a null leafType here;
    // confirm that M5UpdateParameters / UpdateModel accept null.
    private static void UpdateM5Model(IM5MetaModel model, IRegressionProblemData problemData, IRandom random,
      ILeafType<IRegressionModel> leafType = null, CancellationToken? cancellationToken = null) {
      if (cancellationToken == null) cancellationToken = CancellationToken.None;
      var m5Params = new M5UpdateParameters(leafType, problemData, random);
      model.UpdateModel(problemData.TrainingIndices.ToList(), m5Params, cancellationToken.Value);
    }
    #endregion

    #region Helpers
    // Adds the solution (cloned), model-specific analysis results (leaf depth histogram for trees;
    // rules result and coverage diagram for rule sets) and the relative variable frequencies to Results.
    private void AnalyzeSolution(IRegressionSolution solution) {
      Results.Add(new Result("RegressionSolution", (IItem) solution.Clone()));

      Dictionary<string, int> frequencies;
      if (!GenerateRules) {
        Results.Add(M5Analyzer.CreateLeafDepthHistogram((M5TreeModel) solution.Model));
        frequencies = M5Analyzer.GetTreeVariableFrequences((M5TreeModel) solution.Model);
      }
      else {
        Results.Add(M5Analyzer.CreateRulesResult((M5RuleSetModel) solution.Model, Problem.ProblemData, "M5TreeResult", true));
        frequencies = M5Analyzer.GetRuleVariableFrequences((M5RuleSetModel) solution.Model);
        Results.Add(M5Analyzer.CreateCoverageDiagram((M5RuleSetModel) solution.Model, Problem.ProblemData));
      }

      //Variable frequencies (normalized to sum to 1; a zero total is replaced by 1 to avoid division by zero)
      var sum = frequencies.Values.Sum();
      sum = sum == 0 ? 1 : sum;
      var impactArray = new DoubleArray(frequencies.Select(i => (double) i.Value / sum).ToArray()) {
        ElementNames = frequencies.Select(i => i.Key)
      };
      Results.Add(new Result("Variable Frequences", "relative frequencies of variables in rules and tree nodes", impactArray));
    }
    #endregion
  }
}
---|