Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs @ 17986

Last change on this file since 17986 was 17779, checked in by gkronber, 4 years ago

#3022: made a few changes while reviewing the code.

File size: 15.4 KB
Line 
1using System;
2using System.Threading;
3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Optimization;
8using HeuristicLab.Parameters;
9using HEAL.Attic;
10using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
11using HeuristicLab.Problems.DataAnalysis;
12using System.Collections.Generic;
13using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
14using System.Collections;
15
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

    /// <summary>
    /// Fast Function Extraction (FFX) for regression problems.
    /// For every enabled combination of feature kinds ("approach") it generates candidate basis
    /// functions, fits an elastic-net path over them, turns every point of the path into a model,
    /// and finally keeps the Pareto front of (number of basis functions, test MSE) over all models.
    /// </summary>
    [Item(Name = "FastFunctionExtraction", Description = "Implementation of the Fast Function Extraction (FFX) algorithm in C#.")]
    [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
    [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")]
    public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<IRegressionProblem> {

        #region constants
        // Settings handed unchanged to every Approach created in CreateApproaches.
        private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 };
        private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None };
        private static readonly double minHingeThr = 0.2;
        private static readonly double maxHingeThr = 0.9;
        private static readonly int numHingeThrs = 5;

        private const string ConsiderInteractionsParameterName = "Consider Interactions";
        private const string ConsiderDenominationParameterName = "Consider Denomination";
        private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
        private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
        private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions";
        private const string LambdaParameterName = "Elastic Net Lambda";
        private const string PenaltyParameterName = "Elastic Net Penalty";
        private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions";

        #endregion

        #region parameters
        public IValueParameter<BoolValue> ConsiderInteractionsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderDenominationsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
        }
        public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
            get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
        }
        public IValueParameter<DoubleValue> PenaltyParameter {
            get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
        }
        // NOTE(review): the lambda value is registered as a parameter but is not read by Run/Fit.
        public IValueParameter<DoubleValue> LambdaParameter {
            get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
        }
        public IValueParameter<IntValue> MaxNumBasisFuncsParameter {
            get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; }
        }
        #endregion

        #region properties
        public bool ConsiderInteractions {
            get { return ConsiderInteractionsParameter.Value.Value; }
            set { ConsiderInteractionsParameter.Value.Value = value; }
        }
        public bool ConsiderDenominations {
            get { return ConsiderDenominationsParameter.Value.Value; }
            set { ConsiderDenominationsParameter.Value.Value = value; }
        }
        public bool ConsiderExponentiations {
            get { return ConsiderExponentiationsParameter.Value.Value; }
            set { ConsiderExponentiationsParameter.Value.Value = value; }
        }
        public bool ConsiderNonlinearFunctions {
            get { return ConsiderNonlinearFuncsParameter.Value.Value; }
            set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
        }
        public bool ConsiderHingeFunctions {
            get { return ConsiderHingeFuncsParameter.Value.Value; }
            set { ConsiderHingeFuncsParameter.Value.Value = value; }
        }
        public double Penalty {
            get { return PenaltyParameter.Value.Value; }
            set { PenaltyParameter.Value.Value = value; }
        }
        public DoubleValue Lambda {
            get { return LambdaParameter.Value; }
            set { LambdaParameter.Value = value; }
        }
        public int MaxNumBasisFuncs {
            get { return MaxNumBasisFuncsParameter.Value.Value; }
            set { MaxNumBasisFuncsParameter.Value.Value = value; }
        }
        #endregion

        #region ctor

        [StorableConstructor]
        private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
        public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
        }
        public FastFunctionExtraction() : base() {
            base.Problem = new RegressionProblem();
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
            Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Set how many basis functions the models can have at most. if Max Num Basis Funcs is negative => no restriction on size", new IntValue(20)));
            Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
            Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.05)));
        }

        [StorableHook(HookType.AfterDeserialization)]
        private void AfterDeserialization() { }

        public override IDeepCloneable Clone(Cloner cloner) {
            return new FastFunctionExtraction(this, cloner);
        }
        #endregion

        public override Type ProblemType { get { return typeof(RegressionProblem); } }
        public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

        /// <summary>
        /// Runs FFX with the current parameter values and publishes one result per Pareto-optimal
        /// model (named by its number of basis functions) plus a collection of regression solutions.
        /// </summary>
        protected override void Run(CancellationToken cancellationToken) {
            var models = Fit(Problem.ProblemData, Penalty, out var numBases, ConsiderExponentiations,
                ConsiderNonlinearFunctions, ConsiderInteractions,
                ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs
            ).ToArray();

            var numBasesArr = numBases.ToArray();
            var solutions = new List<ISymbolicRegressionSolution>(models.Length);

            for (int i = 0; i < models.Length; i++) {
                var model = models[i];
                Results.Add(new Result(
                    "Num Bases: " + numBasesArr[i],
                    model
                ));
                solutions.Add(new SymbolicRegressionSolution(model, Problem.ProblemData));
            }
            Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutions)));
        }

        /// <summary>
        /// Fits FFX models for every allowed approach and returns the Pareto-optimal ones
        /// (fewest basis functions vs. lowest test MSE), ordered by number of basis functions.
        /// </summary>
        /// <param name="data">Problem data; models are evaluated on its test partition.</param>
        /// <param name="elnetPenalty">Elastic-net mixing factor (0 = ridge, 1 = lasso).</param>
        /// <param name="numBases">Number of basis functions of each returned model, in the same order.</param>
        /// <param name="maxNumBases">Maximum number of basis functions per model; negative means unrestricted.</param>
        /// <returns>The nondominated models as symbolic regression models.</returns>
        public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1) {
            var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty);

            var allFFXModels = approaches
                .SelectMany(approach => CreateFFXModels(data, approach)).ToList();

            // Final Pareto filter over all generated models from all different approaches.
            // Materialized once so numBases and the returned models are guaranteed to line up
            // and repeated enumeration of the result does not rebuild the models.
            var nondominatedFFXModels = NondominatedModels(data, allFFXModels).ToArray();

            numBases = nondominatedFFXModels
                .Select(ffxModel => ffxModel.NumBases).ToArray();
            return nondominatedFFXModels
                .Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable))
                .ToArray();
        }

        // Returns the Pareto front of the given models w.r.t. (number of basis functions, test MSE),
        // both minimized, ordered by ascending number of basis functions.
        private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) {
            var models = ffxModels.ToArray();
            // the observed test targets are identical for every model, so fetch them only once
            var originalValues = data.TargetVariableTestValues.ToArray();

            var qualities = new double[models.Length][];
            for (int i = 0; i < models.Length; i++) {
                var estimatedValues = models[i].Simulate(data, data.TestIndices);
                // do not create a regressionSolution here for better performance:
                // RegressionSolutions calculate all kinds of errors when calling the ctor, but we only need testMSE
                var testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state);
                if (state != OnlineCalculatorError.None) throw new ArrayTypeMismatchException("could not calculate TestMSE");
                qualities[i] = new double[] { models[i].NumBases, testMSE };
            }

            return DominationCalculator<FFXModel>.CalculateBestParetoFront(models, qualities, new bool[] { false, false })
                .Select(tuple => tuple.Item1).OrderBy(ffxModel => ffxModel.NumBases);
        }

        // Builds all FFX models for a single approach: basis-function generation, elastic-net path
        // on normalized data, back-transformation of the coefficients and a final coefficient
        // optimization on the training data.
        private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) {
            // FFX Step 1: candidate basis functions
            var funcsArr = BFUtils.CreateBasisFunctions(data, approach).ToArray();

            // FFX Step 2: elastic-net path on normalized data
            var elnetData = BFUtils.PrepareData(data, funcsArr);
            var normalizedElnetData = BFUtils.Normalize(elnetData, out var X_avgs, out var X_stds, out var y_avg, out var y_std);

            ElasticNetLinearRegression.RunElasticNetLinearRegression(normalizedElnetData, approach.ElasticNetPenalty, out var _, out var _, out var _, out var candidateCoeffsNorm, out var interceptNorm, maxVars: approach.MaxNumBases);

            var coefs = RebiasCoefs(candidateCoeffsNorm, interceptNorm, X_avgs, X_stds, y_avg, y_std, out var intercept);

            // create models out of the learned coefficients
            var ffxModels = GetModelsFromCoeffs(coefs, intercept, funcsArr, approach);

            // one last LS-optimization step on the training data
            foreach (var ffxModel in ffxModels) {
                if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data);
            }
            return ffxModels;
        }

        // Maps coefficients/intercepts learned on normalized data (x' = (x - avg) / std,
        // y' = (y - y_avg) / y_std) back to the original scale. Each row is one point of the path.
        private static double[,] RebiasCoefs(double[,] unbiasedCoefs, double[] unbiasedIntercepts, double[] X_avgs, double[] X_stds, double y_avg, double y_std, out double[] rebiasedIntercepts) {
            var rows = unbiasedIntercepts.Length;
            var cols = X_stds.Length;
            var rebiasedCoefs = new double[rows, cols];
            rebiasedIntercepts = new double[rows];

            for (int i = 0; i < rows; i++) {
                rebiasedIntercepts[i] = unbiasedIntercepts[i] * y_std + y_avg;

                for (int j = 0; j < cols; j++) {
                    var unbiasedCoef = unbiasedCoefs[i, j];
                    // An unused basis function stays unused. This also avoids 0 / 0 = NaN for a
                    // constant basis function (X_stds[j] == 0), which would poison the intercept.
                    if (unbiasedCoef == 0.0) continue;
                    rebiasedCoefs[i, j] = unbiasedCoef * y_std / X_stds[j];
                    rebiasedIntercepts[i] -= rebiasedCoefs[i, j] * X_avgs[j];
                }
            }
            return rebiasedCoefs;
        }

        // Finds all models with unique combinations of basis functions along the elastic-net path.
        // Rows with more nonzero coefficients than approach.MaxNumBases are skipped unless the
        // maximum is negative (= no restriction on size).
        private static IEnumerable<FFXModel> GetModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) {
            var ffxModels = new List<FFXModel>();
            var seenCombinations = new HashSet<string>();
            var limitSize = approach.MaxNumBases >= 0;

            for (int i = 0; i < intercept.Length; i++) {
                var row = candidateCoeffs.GetRow(i);
                var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray();
                if (limitSize && nonzeroIndices.Length > approach.MaxNumBases) continue;
                // ignore duplicate models (models with same combination of basis functions);
                // the first occurrence along the path is kept
                if (!seenCombinations.Add(string.Join(",", nonzeroIndices))) continue;
                var ffxModel = new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx])));
                ffxModels.Add(ffxModel);
            }
            return ffxModels;
        }

        // Enumerates every subset of the five feature kinds (interactions, denominator,
        // exponentiations, nonlinear functions, hinge functions) that is enabled by the caller,
        // uses at most three kinds at once and does not combine interactions with exponentiations.
        private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) {
            var approaches = new List<Approach>();
            var valids = new bool[5] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions };

            // return true if ALL indices of true values of arr1 also have true values in arr2
            bool follows(BitArray arr1, bool[] arr2) {
                if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths");
                for (int i = 0; i < arr1.Length; i++) {
                    if (arr1[i] && !arr2[i]) return false;
                }
                return true;
            }

            for (int i = 0; i < 32; i++) { // Iterate all combinations of 5 bools.
                // map i to a bool array of length 5: bit b of i switches feature kind b on
                var arr = new BitArray(5);
                var popCount = 0;
                for (int b = 0; b < arr.Length; b++) {
                    if (((i >> b) & 1) == 1) {
                        arr[b] = true;
                        popCount++;
                    }
                }

                if (!follows(arr, valids)) continue;
                if (popCount >= 4) continue; // not too many features at once
                if (arr[0] && arr[2]) continue; // never need both exponent and inter
                approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4], exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs));
            }
            return approaches;
        }
    }
}
Note: See TracBrowser for help on using the repository browser.