using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HEAL.Attic;
using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

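  // Fast Function Extraction (FFX): builds a large pool of candidate basis functions,
  // fits sparse linear models over them with a pathwise elastic net, and keeps only the
  // Pareto-optimal models with respect to model size (number of bases) and test error.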
  [Item(Name = "FastFunctionExtraction", Description = "Implementation of the Fast Function Extraction (FFX) algorithm in C#.")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")]
  public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<IRegressionProblem> {

    #region constants
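    // Fixed FFX search space: the allowed exponents, the allowed nonlinear operators,
    // and the hinge-function threshold grid (numHingeThrs values between minHingeThr and maxHingeThr).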
    private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 };
    private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None };
    private static readonly double minHingeThr = 0.2;
    private static readonly double maxHingeThr = 0.9;
    private static readonly int numHingeThrs = 5;

    private const string ConsiderInteractionsParameterName = "Consider Interactions";
    private const string ConsiderDenominationParameterName = "Consider Denomination";
    private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions";
    private const string LambdaParameterName = "Elastic Net Lambda";
    private const string PenaltyParameterName = "Elastic Net Penalty";
    private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions";

    #endregion

    #region parameters
    public IValueParameter<BoolValue> ConsiderInteractionsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderDenominationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
    }
    public IValueParameter<DoubleValue> PenaltyParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
    }
    public IValueParameter<DoubleValue> LambdaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    }
    public IValueParameter<IntValue> MaxNumBasisFuncsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; }
    }
    #endregion

    #region properties
    public bool ConsiderInteractions {
      get { return ConsiderInteractionsParameter.Value.Value; }
      set { ConsiderInteractionsParameter.Value.Value = value; }
    }
    public bool ConsiderDenominations {
      get { return ConsiderDenominationsParameter.Value.Value; }
      set { ConsiderDenominationsParameter.Value.Value = value; }
    }
    public bool ConsiderExponentiations {
      get { return ConsiderExponentiationsParameter.Value.Value; }
      set { ConsiderExponentiationsParameter.Value.Value = value; }
    }
    public bool ConsiderNonlinearFunctions {
      get { return ConsiderNonlinearFuncsParameter.Value.Value; }
      set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
    }
    public bool ConsiderHingeFunctions {
      get { return ConsiderHingeFuncsParameter.Value.Value; }
      set { ConsiderHingeFuncsParameter.Value.Value = value; }
    }
    public double Penalty {
      get { return PenaltyParameter.Value.Value; }
      set { PenaltyParameter.Value.Value = value; }
    }
    public DoubleValue Lambda {
      get { return LambdaParameter.Value; }
      set { LambdaParameter.Value = value; }
    }
    public int MaxNumBasisFuncs {
      get { return MaxNumBasisFuncsParameter.Value.Value; }
      set { MaxNumBasisFuncsParameter.Value.Value = value; }
    }
    #endregion

    #region ctor

    [StorableConstructor]
    private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
    public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
    }
    public FastFunctionExtraction() : base() {
      base.Problem = new RegressionProblem();
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominators, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions (abs, log, ...), otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include hinge functions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Sets how many basis functions the models can have at most. If Max Num Basis Funcs is negative, there is no restriction on model size.", new IntValue(20)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. If lambda is not set, the whole path of lambdas is calculated."));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression.", new DoubleValue(0.05)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FastFunctionExtraction(this, cloner);
    }
    #endregion

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

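    // Runs FFX on the problem data: fits the Pareto-optimal models, adds one result per model
    // (keyed by its number of bases) and a collection with the corresponding regression solutions.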
    protected override void Run(CancellationToken cancellationToken) {
      var models = Fit(Problem.ProblemData, Penalty, out var numBases, ConsiderExponentiations,
        ConsiderNonlinearFunctions, ConsiderInteractions,
        ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs
      );

      int i = 0;
      var numBasesArr = numBases.ToArray();
      var solutionsArr = new List<ISymbolicRegressionSolution>();

      foreach (var model in models) {
        Results.Add(new Result(
          "Num Bases: " + numBasesArr[i++],
          model
        ));
        solutionsArr.Add(new SymbolicRegressionSolution(model, Problem.ProblemData));
      }
      Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutionsArr)));
    }

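    // Static FFX entry point: (1) create basis functions for every enabled approach, (2) fit sparse
    // linear models over them with the elastic net, (3) return the non-dominated models (number of
    // bases vs. test error). A minimal usage sketch, assuming an IRegressionProblemData instance 'problemData':
    //   var models = FastFunctionExtraction.Fit(problemData, elnetPenalty: 0.05, out var numBases);
    //   var solutions = models.Select(m => new SymbolicRegressionSolution(m, problemData));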
    public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1) {
      var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty);

      var allFFXModels = approaches
        .SelectMany(approach => CreateFFXModels(data, approach)).ToList();

      // Final Pareto filter over all generated models from all different approaches
      var nondominatedFFXModels = NondominatedModels(data, allFFXModels);

      numBases = nondominatedFFXModels
        .Select(ffxModel => ffxModel.NumBases).ToArray();
      return nondominatedFFXModels.Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable));
    }

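    // Pareto filter: minimizes both the number of bases (model complexity) and the MSE on the test partition.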
    private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) {
      var numBases = ffxModels.Select(ffxModel => (double)ffxModel.NumBases).ToArray();
      var errors = ffxModels.Select(ffxModel => {
        var originalValues = data.TargetVariableTestValues.ToArray();
        var estimatedValues = ffxModel.Simulate(data, data.TestIndices);
        // do not create a RegressionSolution here for better performance:
        // RegressionSolutions calculate all kinds of errors in the ctor, but we only need the test MSE
        var testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state);
        if (state != OnlineCalculatorError.None) throw new InvalidOperationException("could not calculate test MSE");
        return testMSE;
      }).ToArray();

      int n = numBases.Length;
      double[][] qualities = new double[n][];
      for (int i = 0; i < n; i++) {
        qualities[i] = new double[2];
        qualities[i][0] = numBases[i];
        qualities[i][1] = errors[i];
      }

      return DominationCalculator<FFXModel>.CalculateBestParetoFront(ffxModels.ToArray(), qualities, new bool[] { false, false })
        .Select(tuple => tuple.Item1).OrderBy(ffxModel => ffxModel.NumBases);
    }

    // Build FFX models
    private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) {
      // FFX Step 1
      var basisFunctions = BFUtils.CreateBasisFunctions(data, approach).ToArray();

      // FFX Step 2
      var funcsArr = basisFunctions.ToArray();
      var elnetData = BFUtils.PrepareData(data, funcsArr);
      var normalizedElnetData = BFUtils.Normalize(elnetData, out var X_avgs, out var X_stds, out var y_avg, out var y_std);

      ElasticNetLinearRegression.RunElasticNetLinearRegression(normalizedElnetData, approach.ElasticNetPenalty, out var _, out var _, out var _, out var candidateCoeffsNorm, out var interceptNorm, maxVars: approach.MaxNumBases);

      var coefs = RebiasCoefs(candidateCoeffsNorm, interceptNorm, X_avgs, X_stds, y_avg, y_std, out var intercept);

      // create models out of the learned coefficients
      var ffxModels = GetModelsFromCoeffs(coefs, intercept, funcsArr, approach);

      // one last LS-optimization step on the training data
      foreach (var ffxModel in ffxModels) {
        if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data);
      }
      return ffxModels;
    }

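    // The elastic net is run on normalized (z-scored) features and target; this converts the learned
    // coefficients and intercepts back to the original scale of the data.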
    private static double[,] RebiasCoefs(double[,] unbiasedCoefs, double[] unbiasedIntercepts, double[] X_avgs, double[] X_stds, double y_avg, double y_std, out double[] rebiasedIntercepts) {
      var rows = unbiasedIntercepts.Length;
      var cols = X_stds.Length;
      var rebiasedCoefs = new double[rows, cols];
      rebiasedIntercepts = new double[rows];

      for (int i = 0; i < rows; i++) {
        var unbiasedIntercept = unbiasedIntercepts[i];
        rebiasedIntercepts[i] = unbiasedIntercept * y_std + y_avg;

        for (int j = 0; j < cols; j++) {
          rebiasedCoefs[i, j] = unbiasedCoefs[i, j] * y_std / X_stds[j];
          rebiasedIntercepts[i] -= rebiasedCoefs[i, j] * X_avgs[j];
        }
      }
      return rebiasedCoefs;
    }

    // finds all models with unique combinations of basis functions
    private static IEnumerable<FFXModel> GetModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) {
      List<FFXModel> ffxModels = new List<FFXModel>();

      for (int i = 0; i < intercept.Length; i++) {
        var row = candidateCoeffs.GetRow(i);
        var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray();
        if (nonzeroIndices.Length > approach.MaxNumBases) continue;
        // ignore duplicate models (models with same combination of basis functions)
        var ffxModel = new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx])));
        ffxModels.Add(ffxModel);
      }
      return ffxModels;
    }

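    // Enumerates all allowed combinations of the five feature flags (interactions, denominator,
    // exponentiation, nonlinear functions, hinge functions) and creates one Approach per combination.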
    private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) {
      var approaches = new List<Approach>();
      var valids = new bool[5] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions };

      // return true if ALL indices of true values of arr1 also have true values in arr2
      bool follows(BitArray arr1, bool[] arr2) {
        if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths");
        for (int i = 0; i < arr1.Length; i++) {
          if (arr1[i] && !arr2[i]) return false;
        }
        return true;
      }

      for (int i = 0; i < 32; i++) { // Iterate all combinations of 5 bools.
        // map i to a bool array of length 5
        var v = i;
        int b = 0;
        var arr = new BitArray(5);
        var popCount = 0;
        while (v > 0) { if (v % 2 == 1) { arr[b] = true; popCount++; } b++; v /= 2; }

        if (!follows(arr, valids)) continue;
        if (popCount >= 4) continue; // not too many features at once
        if (arr[0] && arr[2]) continue; // never need both interactions and exponentiation
        approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4], exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs));
      }
      return approaches;
    }
  }
}