Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs @ 17738

Last change on this file since 17738 was 17738, checked in by lleko, 4 years ago

#3022 implement ffx

File size: 14.4 KB
Line 
1using System;
2using System.Threading;
3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Optimization;
8using HeuristicLab.Parameters;
9using HEAL.Attic;
10using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
11using HeuristicLab.Problems.DataAnalysis;
12using System.Collections.Generic;
13using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
14
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

    /// <summary>
    /// Fast Function Extraction (FFX): deterministically builds symbolic regression models by
    /// (1) generating candidate basis functions for several feature "approaches",
    /// (2) fitting sparse coefficient vectors with elastic-net linear regression, and
    /// (3) keeping only the pareto-optimal trade-offs between model size (number of bases)
    /// and test mean squared error.
    /// </summary>
    [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
    [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
    [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a file or transferring it to HeuristicLab.Hive)
    public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {

        #region constants
        // candidate exponents used when exponentiation basis functions are enabled
        private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 };
        // candidate nonlinear operators used when nonlinear basis functions are enabled
        private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None };
        // hinge-function thresholds are sampled from [minHingeThr, maxHingeThr] at numHingeThrs points
        private static readonly double minHingeThr = 0.2;
        private static readonly double maxHingeThr = 0.8;
        private static readonly int numHingeThrs = 5;

        private const string ConsiderInteractionsParameterName = "Consider Interactions";
        private const string ConsiderDenominationParameterName = "Consider Denomination";
        private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
        private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
        private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions";
        private const string LambdaParameterName = "Elastic Net Lambda";
        private const string PenaltyParameterName = "Elastic Net Penalty";
        private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions";

        #endregion

        #region parameters
        public IValueParameter<BoolValue> ConsiderInteractionsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderDenominationsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
        }
        public IValueParameter<DoubleValue> PenaltyParameter {
            get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
        }
        public IValueParameter<DoubleValue> LambdaParameter {
            get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
        }
        public IValueParameter<IntValue> MaxNumBasisFuncsParameter {
            get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; }
        }
        #endregion

        #region properties
        // Convenience wrappers around the typed parameter values above.
        public bool ConsiderInteractions {
            get { return ConsiderInteractionsParameter.Value.Value; }
            set { ConsiderInteractionsParameter.Value.Value = value; }
        }
        public bool ConsiderDenominations {
            get { return ConsiderDenominationsParameter.Value.Value; }
            set { ConsiderDenominationsParameter.Value.Value = value; }
        }
        public bool ConsiderExponentiations {
            get { return ConsiderExponentiationsParameter.Value.Value; }
            set { ConsiderExponentiationsParameter.Value.Value = value; }
        }
        public bool ConsiderNonlinearFunctions {
            get { return ConsiderNonlinearFuncsParameter.Value.Value; }
            set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
        }
        public bool ConsiderHingeFunctions {
            get { return ConsiderHingeFuncsParameter.Value.Value; }
            set { ConsiderHingeFuncsParameter.Value.Value = value; }
        }
        public double Penalty {
            get { return PenaltyParameter.Value.Value; }
            set { PenaltyParameter.Value.Value = value; }
        }
        // Lambda is optional (may be null => calculate the whole elastic-net path),
        // hence the DoubleValue wrapper instead of a plain double.
        public DoubleValue Lambda {
            get { return LambdaParameter.Value; }
            set { LambdaParameter.Value = value; }
        }
        public int MaxNumBasisFuncs {
            get { return MaxNumBasisFuncsParameter.Value.Value; }
            set { MaxNumBasisFuncsParameter.Value.Value = value; }
        }
        #endregion

        #region ctor

        [StorableConstructor]
        private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
        public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
        }
        public FastFunctionExtraction() : base() {
            base.Problem = new RegressionProblem();
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Set how many basis functions the models can have at most. if Max Num Basis Funcs is negative => no restriction on size", new IntValue(20)));
            Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
            Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.05)));
        }

        [StorableHook(HookType.AfterDeserialization)]
        private void AfterDeserialization() { }

        public override IDeepCloneable Clone(Cloner cloner) {
            return new FastFunctionExtraction(this, cloner);
        }
        #endregion

        public override Type ProblemType { get { return typeof(RegressionProblem); } }
        public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

        /// <summary>
        /// Runs FFX on the current problem data and publishes one result per pareto-optimal
        /// model (keyed by its number of bases) plus a collection of all solutions.
        /// </summary>
        protected override void Run(CancellationToken cancellationToken) {
            var models = Fit(Problem.ProblemData, Penalty, out var numBases, ConsiderExponentiations,
                ConsiderNonlinearFunctions, ConsiderInteractions,
                ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs
            );

            // numBases yields one entry per model, in the same order as models
            int i = 0;
            var numBasesArr = numBases.ToArray();
            var solutionsArr = new List<ISymbolicRegressionSolution>();

            // NOTE(review): if the pareto front ever contains two models with the same
            // number of bases, the result names below collide — verify DominationCalculator
            // never returns such duplicates.
            foreach (var model in models) {
                Results.Add(new Result(
                    "Num Bases: " + numBasesArr[i++],
                    model
                ));
                solutionsArr.Add(new SymbolicRegressionSolution(model, Problem.ProblemData));
            }
            Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutionsArr)));
        }

        /// <summary>
        /// Core FFX entry point: generates basis functions for every enabled approach, fits
        /// elastic-net models, and returns the pareto-optimal models (size vs. test error),
        /// ordered by ascending number of bases.
        /// </summary>
        /// <param name="data">Regression problem data providing training and test partitions.</param>
        /// <param name="elnetPenalty">Elastic-net alpha balancing ridge (0.0) and lasso (1.0).</param>
        /// <param name="numBases">Out: number of bases of each returned model, same order as the return value.</param>
        /// <param name="maxNumBases">Maximum model size; negative => no restriction.</param>
        public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1) {
            var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty);

            var allFFXModels = approaches
                .SelectMany(approach => CreateFFXModels(data, approach)).ToList();

            // Final pareto filter over all generated models from all different approaches.
            // Materialize once so the out-parameter and the returned sequence agree.
            var nondominatedFFXModels = NondominatedModels(data, allFFXModels).ToList();

            numBases = nondominatedFFXModels
                .Select(ffxModel => ffxModel.NumBases).ToArray();
            return nondominatedFFXModels.Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable));
        }

        /// <summary>
        /// Selects the pareto-optimal models, minimizing both number of bases and test MSE.
        /// </summary>
        private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) {
            // materialize once: the input is enumerated per-model below and the qualities
            // array must line up index-by-index with the model array
            var models = ffxModels.ToArray();
            // the target test values are the same for every model — fetch them only once
            var originalValues = data.TargetVariableTestValues.ToArray();

            int n = models.Length;
            var qualities = new double[n][];
            for (int i = 0; i < n; i++) {
                var estimatedValues = models[i].Simulate(data, data.TestIndices);
                // do not create a regressionSolution here for better performance:
                // RegressionSolutions calculate all kinds of errors when calling the ctor, but we only need testMSE
                var testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state);
                if (state != OnlineCalculatorError.None) throw new InvalidOperationException("could not calculate TestMSE");
                // objective vector: [model size, test error]
                qualities[i] = new double[] { models[i].NumBases, testMSE };
            }

            // both objectives are minimized (maximization flags all false)
            return DominationCalculator<FFXModel>.CalculateBestParetoFront(models, qualities, new bool[] { false, false })
                .Select(tuple => tuple.Item1).OrderBy(ffxModel => ffxModel.NumBases);
        }

        /// <summary>
        /// Builds all FFX models for a single approach: generate basis functions (step 1),
        /// fit a sparse linear model over them with elastic net (step 2), then fine-tune
        /// each candidate's coefficients by least squares on the training data.
        /// </summary>
        private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) {
            // FFX Step 1
            var funcsArr = BFUtils.CreateBasisFunctions(data, approach).ToArray();
            var elnetData = BFUtils.PrepareData(data, funcsArr);

            // FFX Step 2
            ElasticNetLinearRegression.RunElasticNetLinearRegression(elnetData, approach.ElnetPenalty, out var _, out var _, out var _, out var candidateCoeffs, out var intercept, maxVars: approach.MaxNumBases);

            // create models out of the learned coefficients
            var ffxModels = GetUniqueModelsFromCoeffs(candidateCoeffs, intercept, funcsArr, approach);

            // one last LS-optimization step on the training data
            foreach (var ffxModel in ffxModels) {
                if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data);
            }
            return ffxModels;
        }

        /// <summary>
        /// Converts each row of the elastic-net coefficient path into an FFXModel, keeping
        /// only models with a unique combination of basis functions and at most
        /// <c>approach.MaxNumBases</c> nonzero coefficients.
        /// </summary>
        private static IEnumerable<FFXModel> GetUniqueModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) {
            List<FFXModel> ffxModels = new List<FFXModel>();
            // combinations of basis-function indices already turned into a model
            List<int[]> bfCombinations = new List<int[]>();

            // one row (and one intercept) per lambda value along the elastic-net path
            for (int i = 0; i < intercept.Length; i++) {
                var row = candidateCoeffs.GetRow(i);
                var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray();
                if (nonzeroIndices.Length > approach.MaxNumBases) continue;
                // ignore duplicate models (models with same combination of basis functions)
                if (bfCombinations.Any(arr => arr.SequenceEqual(nonzeroIndices))) continue;
                var ffxModel = new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx])));
                bfCombinations.Add(nonzeroIndices);
                ffxModels.Add(ffxModel);
            }
            return ffxModels;
        }

        /// <summary>
        /// Enumerates all feature subsets (interactions, denominator, exponentiation,
        /// nonlinear functions, hinge functions) allowed by the user's switches, pruning
        /// combinations that are too complex or redundant.
        /// </summary>
        private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) {
            var approaches = new List<Approach>();
            var valids = new bool[5] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions };

            // return true if ALL indices of true values of arr1 also have true values in arr2
            bool follows(bool[] arr1, bool[] arr2) {
                if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths");
                for (int i = 0; i < arr1.Length; i++) {
                    if (arr1[i] && !arr2[i]) return false;
                }
                return true;
            }

            // iterate over all 2^5 feature subsets
            for (int i = 0; i < 32; i++) {
                // map i to a bool array of length 5
                var arr = i.ToBoolArray(5);
                // skip subsets that enable a feature the user disabled
                if (!follows(arr, valids)) continue;
                int sum = arr.Count(b => b); // how many features are enabled?
                if (sum >= 4) continue; // not too many features at once
                if (arr[0] && arr[2]) continue; // never need both exponent and inter
                approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4], exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs));
            }
            return approaches;
        }
    }
}
Note: See TracBrowser for help on using the repository browser.