1 | using System;
|
---|
2 | using System.Threading;
|
---|
3 | using System.Linq;
|
---|
4 | using HeuristicLab.Common;
|
---|
5 | using HeuristicLab.Core;
|
---|
6 | using HeuristicLab.Data;
|
---|
7 | using HeuristicLab.Optimization;
|
---|
8 | using HeuristicLab.Parameters;
|
---|
9 | using HEAL.Attic;
|
---|
10 | using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
|
---|
11 | using HeuristicLab.Problems.DataAnalysis;
|
---|
12 | using System.Collections.Generic;
|
---|
13 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
14 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

  /// <summary>
  /// Fast Function Extraction (FFX) regression.
  /// For every enabled combination of features ("approach"), it builds candidate basis functions (FFX step 1).
  /// It then fits an elastic-net path over them (FFX step 2) and turns each distinct set of non-zero
  /// coefficients into a model. Finally it keeps only the models that are Pareto-optimal with respect to
  /// (number of basis functions, test MSE).
  /// </summary>
  [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing the algorithm to a file or transferring it to HeuristicLab.Hive)
  public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {

    #region constants
    private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 };
    private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None };
    private static readonly double minHingeThr = 0.2;
    private static readonly double maxHingeThr = 0.8;
    private static readonly int numHingeThrs = 5;

    private const string ConsiderInteractionsParameterName = "Consider Interactions";
    private const string ConsiderDenominationParameterName = "Consider Denomination";
    private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions";
    private const string LambdaParameterName = "Elastic Net Lambda";
    private const string PenaltyParameterName = "Elastic Net Penalty";
    private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions";

    #endregion

    #region parameters
    public IValueParameter<BoolValue> ConsiderInteractionsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderDenominationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
    }
    public IValueParameter<DoubleValue> PenaltyParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
    }
    public IValueParameter<DoubleValue> LambdaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    }
    public IValueParameter<IntValue> MaxNumBasisFuncsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; }
    }
    #endregion

    #region properties
    public bool ConsiderInteractions {
      get { return ConsiderInteractionsParameter.Value.Value; }
      set { ConsiderInteractionsParameter.Value.Value = value; }
    }
    public bool ConsiderDenominations {
      get { return ConsiderDenominationsParameter.Value.Value; }
      set { ConsiderDenominationsParameter.Value.Value = value; }
    }
    public bool ConsiderExponentiations {
      get { return ConsiderExponentiationsParameter.Value.Value; }
      set { ConsiderExponentiationsParameter.Value.Value = value; }
    }
    public bool ConsiderNonlinearFunctions {
      get { return ConsiderNonlinearFuncsParameter.Value.Value; }
      set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
    }
    public bool ConsiderHingeFunctions {
      get { return ConsiderHingeFuncsParameter.Value.Value; }
      set { ConsiderHingeFuncsParameter.Value.Value = value; }
    }
    public double Penalty {
      get { return PenaltyParameter.Value.Value; }
      set { PenaltyParameter.Value.Value = value; }
    }
    /// <summary>
    /// Optional elastic-net lambda (null when unset).
    /// NOTE(review): this value is currently not read by <see cref="Run"/> or <see cref="Fit"/>;
    /// the whole lambda path is always computed.
    /// </summary>
    public DoubleValue Lambda {
      get { return LambdaParameter.Value; }
      set { LambdaParameter.Value = value; }
    }
    public int MaxNumBasisFuncs {
      get { return MaxNumBasisFuncsParameter.Value.Value; }
      set { MaxNumBasisFuncsParameter.Value.Value = value; }
    }
    #endregion

    #region ctor

    [StorableConstructor]
    private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
    public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
    }
    public FastFunctionExtraction() : base() {
      base.Problem = new RegressionProblem();
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Set how many basis functions the models can have at most. if Max Num Basis Funcs is negative => no restriction on size", new IntValue(20)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.05)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FastFunctionExtraction(this, cloner);
    }
    #endregion

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    /// <summary>
    /// Runs FFX on the current problem data.
    /// Adds one result per Pareto-optimal model (named after its number of basis functions) and a
    /// "Model Accuracies" collection holding a regression solution for each of those models.
    /// </summary>
    /// <param name="cancellationToken">Checked between approaches so that the algorithm can be stopped.</param>
    protected override void Run(CancellationToken cancellationToken) {
      var models = Fit(Problem.ProblemData, Penalty, out var numBases, ConsiderExponentiations,
        ConsiderNonlinearFunctions, ConsiderInteractions,
        ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs, cancellationToken
      ).ToArray();
      var numBasesArr = numBases.ToArray();

      var solutions = new List<ISymbolicRegressionSolution>(models.Length);
      for (int i = 0; i < models.Length; i++) {
        var model = models[i];
        Results.Add(new Result("Num Bases: " + numBasesArr[i], model));
        solutions.Add(new SymbolicRegressionSolution(model, Problem.ProblemData));
      }
      Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutions)));
    }

    /// <summary>
    /// Builds FFX models for every enabled approach and returns the ones that are Pareto-optimal with
    /// respect to (number of basis functions, test MSE). The models are ordered by number of basis functions.
    /// </summary>
    /// <param name="data">Regression problem data. The test partition is used for the Pareto filter.</param>
    /// <param name="elnetPenalty">Elastic-net mixing factor (0 = ridge, 1 = lasso).</param>
    /// <param name="numBases">Receives the number of basis functions of each returned model, in the same order.</param>
    /// <param name="exp">Allow exponentiation of basis functions.</param>
    /// <param name="nonlinFuncs">Allow nonlinear operators (e.g. abs, log).</param>
    /// <param name="interactions">Allow interactions between basis functions.</param>
    /// <param name="denoms">Allow denominators. Note that the default (false) differs from the algorithm's parameter default.</param>
    /// <param name="hingeFuncs">Allow hinge functions.</param>
    /// <param name="maxNumBases">Maximum number of basis functions per model; a negative value means no restriction.</param>
    /// <param name="cancellationToken">Checked before each approach is processed.</param>
    /// <exception cref="OperationCanceledException">When <paramref name="cancellationToken"/> is cancelled.</exception>
    public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1, CancellationToken cancellationToken = default(CancellationToken)) {
      var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty);

      var allFFXModels = new List<FFXModel>();
      foreach (var approach in approaches) {
        cancellationToken.ThrowIfCancellationRequested();
        allFFXModels.AddRange(CreateFFXModels(data, approach));
      }

      // Final pareto filter over all generated models from all different approaches.
      // Materialized once so that numBases and the returned models are guaranteed to stay aligned
      // and the symbolic models are not rebuilt on every enumeration.
      var nondominatedFFXModels = NondominatedModels(data, allFFXModels).ToArray();

      numBases = nondominatedFFXModels.Select(ffxModel => ffxModel.NumBases).ToArray();
      return nondominatedFFXModels.Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable)).ToArray();
    }

    /// <summary>
    /// Returns the models on the best Pareto front when both the number of basis functions and the
    /// test MSE are minimized. The result is ordered by number of basis functions.
    /// </summary>
    /// <exception cref="InvalidOperationException">When the test MSE of a model cannot be calculated.</exception>
    private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) {
      var models = ffxModels.ToArray();
      // the test targets are identical for all models -> fetch them only once
      var originalValues = data.TargetVariableTestValues.ToArray();

      var qualities = new double[models.Length][];
      for (int i = 0; i < models.Length; i++) {
        var estimatedValues = models[i].Simulate(data, data.TestIndices);
        // do not create a regressionSolution here for better performance:
        // RegressionSolutions calculate all kinds of errors when calling the ctor, but we only need testMSE
        var testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state);
        if (state != OnlineCalculatorError.None)
          throw new InvalidOperationException("could not calculate TestMSE (" + state + ")");
        qualities[i] = new double[] { models[i].NumBases, testMSE };
      }

      return DominationCalculator<FFXModel>.CalculateBestParetoFront(models, qualities, new bool[] { false, false })
        .Select(tuple => tuple.Item1).OrderBy(ffxModel => ffxModel.NumBases);
    }

    /// <summary>
    /// Builds the FFX models for a single approach. It creates the basis functions (FFX step 1) and runs
    /// the elastic-net path over them (FFX step 2). It then turns every distinct set of non-zero coefficients
    /// into a model and re-optimizes its coefficients on the training data.
    /// </summary>
    private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) {
      // FFX Step 1
      var funcsArr = BFUtils.CreateBasisFunctions(data, approach).ToArray();
      var elnetData = BFUtils.PrepareData(data, funcsArr);

      // FFX Step 2: only the coefficient matrix and intercepts of the path are needed
      ElasticNetLinearRegression.RunElasticNetLinearRegression(elnetData, approach.ElnetPenalty, out _, out _, out _, out var candidateCoeffs, out var intercept, maxVars: approach.MaxNumBases);

      // create models out of the learned coefficients
      var ffxModels = GetUniqueModelsFromCoeffs(candidateCoeffs, intercept, funcsArr, approach);

      // one last LS-optimization step on the training data
      foreach (var ffxModel in ffxModels) {
        if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data);
      }
      return ffxModels;
    }

    /// <summary>
    /// Creates one model per row of the elastic-net coefficient matrix, skipping two kinds of rows.
    /// Rows that use more than <c>approach.MaxNumBases</c> basis functions are skipped, but only when that
    /// limit is non-negative. Rows whose set of non-zero basis-function indices has already been seen are
    /// also skipped.
    /// </summary>
    private static IEnumerable<FFXModel> GetUniqueModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) {
      var ffxModels = new List<FFXModel>();
      var bfCombinations = new List<int[]>();
      // a negative maximum means "no restriction on size"
      bool limitNumBases = approach.MaxNumBases >= 0;

      for (int i = 0; i < intercept.Length; i++) {
        var row = candidateCoeffs.GetRow(i);
        var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray();
        if (limitNumBases && nonzeroIndices.Length > approach.MaxNumBases) continue;
        // ignore duplicate models (models with same combination of basis functions)
        if (bfCombinations.Any(arr => arr.SequenceEqual(nonzeroIndices))) continue;
        var ffxModel = new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx])));
        bfCombinations.Add(nonzeroIndices);
        ffxModels.Add(ffxModel);
      }
      return ffxModels;
    }

    /// <summary>
    /// Enumerates every subset of the five enabled features
    /// (interactions, denominator, exponentiations, nonlinear functions, hinge functions).
    /// Subsets with four or more features are skipped, as are subsets combining interactions with
    /// exponentiations. Each remaining subset becomes one <see cref="Approach"/>.
    /// </summary>
    private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) {
      var approaches = new List<Approach>();
      var valids = new bool[5] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions };

      // return true if ALL indices of true values of arr1 also have true values in arr2
      bool follows(bool[] arr1, bool[] arr2) {
        if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths");
        for (int i = 0; i < arr1.Length; i++) {
          if (arr1[i] && !arr2[i]) return false;
        }
        return true;
      }

      for (int i = 0; i < 32; i++) {
        // map i to a bool array of length 5
        var arr = i.ToBoolArray(5);
        if (!follows(arr, valids)) continue;
        int sum = arr.Count(b => b); // how many features are enabled?
        if (sum >= 4) continue; // not too many features at once
        if (arr[0] && arr[2]) continue; // never need both exponent and inter
        approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4], exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs));
      }
      return approaches;
    }
  }
}
---|