[17218] | 1 | using System;
|
---|
| 2 | using System.Threading;
|
---|
[17219] | 3 | using System.Linq;
|
---|
[17737] | 4 | using HeuristicLab.Common;
|
---|
| 5 | using HeuristicLab.Core;
|
---|
| 6 | using HeuristicLab.Data;
|
---|
| 7 | using HeuristicLab.Optimization;
|
---|
[17218] | 8 | using HeuristicLab.Parameters;
|
---|
| 9 | using HEAL.Attic;
|
---|
[17219] | 10 | using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
|
---|
| 11 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 12 | using System.Collections.Generic;
|
---|
[17227] | 13 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
[17218] | 14 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

  /// <summary>
  /// Fast Function Extraction (FFX) regression.
  /// For each enabled combination of feature kinds ("approach"), basis functions are generated,
  /// an elastic-net path is fitted over them, and the resulting models are finally filtered to the
  /// Pareto front of (number of basis functions, MSE on held-out data).
  /// </summary>
  [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing the algorithm to a file or transferring it to HeuristicLab.Hive)
  public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {

    #region constants
    // Exponents handed to every Approach (used when exponentiation is enabled).
    private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 };
    // Nonlinear operators handed to every Approach (used when nonlinear functions are enabled).
    private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None };
    // Hinge-threshold settings handed to every Approach.
    private const double minHingeThr = 0.2;
    private const double maxHingeThr = 0.8;
    private const int numHingeThrs = 5;

    private const string ConsiderInteractionsParameterName = "Consider Interactions";
    private const string ConsiderDenominationParameterName = "Consider Denomination";
    private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions";
    private const string LambdaParameterName = "Elastic Net Lambda";
    private const string PenaltyParameterName = "Elastic Net Penalty";
    private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions";

    #endregion

    #region parameters
    public IValueParameter<BoolValue> ConsiderInteractionsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderDenominationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
    }
    public IValueParameter<DoubleValue> PenaltyParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
    }
    /// <summary>
    /// Optional elastic-net lambda. NOTE: this value is currently not read by <see cref="Run"/>;
    /// the whole lambda path is always computed.
    /// </summary>
    public IValueParameter<DoubleValue> LambdaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    }
    public IValueParameter<IntValue> MaxNumBasisFuncsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; }
    }
    #endregion

    #region properties
    public bool ConsiderInteractions {
      get { return ConsiderInteractionsParameter.Value.Value; }
      set { ConsiderInteractionsParameter.Value.Value = value; }
    }
    public bool ConsiderDenominations {
      get { return ConsiderDenominationsParameter.Value.Value; }
      set { ConsiderDenominationsParameter.Value.Value = value; }
    }
    public bool ConsiderExponentiations {
      get { return ConsiderExponentiationsParameter.Value.Value; }
      set { ConsiderExponentiationsParameter.Value.Value = value; }
    }
    public bool ConsiderNonlinearFunctions {
      get { return ConsiderNonlinearFuncsParameter.Value.Value; }
      set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
    }
    public bool ConsiderHingeFunctions {
      get { return ConsiderHingeFuncsParameter.Value.Value; }
      set { ConsiderHingeFuncsParameter.Value.Value = value; }
    }
    public double Penalty {
      get { return PenaltyParameter.Value.Value; }
      set { PenaltyParameter.Value.Value = value; }
    }
    /// <summary>Optional lambda value (may be null). Currently not consumed by <see cref="Run"/>.</summary>
    public DoubleValue Lambda {
      get { return LambdaParameter.Value; }
      set { LambdaParameter.Value = value; }
    }
    /// <summary>Maximum number of basis functions per model; a negative value means no restriction.</summary>
    public int MaxNumBasisFuncs {
      get { return MaxNumBasisFuncsParameter.Value.Value; }
      set { MaxNumBasisFuncsParameter.Value.Value = value; }
    }
    #endregion

    #region ctor

    [StorableConstructor]
    private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
    public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
    }
    public FastFunctionExtraction() : base() {
      base.Problem = new RegressionProblem();
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Set how many basis functions the models can have at most. if Max Num Basis Funcs is negative => no restriction on size", new IntValue(20)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.05)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FastFunctionExtraction(this, cloner);
    }
    #endregion

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    /// <summary>
    /// Runs FFX on the current problem data and publishes one result per Pareto-optimal model
    /// ("Num Bases: k") plus a collection of the corresponding regression solutions.
    /// Note: <paramref name="cancellationToken"/> is not observed because <see cref="Fit"/> has no cancellation hook.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;
      var models = Fit(problemData, Penalty, out var numBases, ConsiderExponentiations,
        ConsiderNonlinearFunctions, ConsiderInteractions,
        ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs
      ).ToArray();
      var numBasesArr = numBases.ToArray(); // same order as models

      var solutions = new List<ISymbolicRegressionSolution>(models.Length);
      for (int i = 0; i < models.Length; i++) {
        Results.Add(new Result("Num Bases: " + numBasesArr[i], models[i]));
        solutions.Add(new SymbolicRegressionSolution(models[i], problemData));
      }
      Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutions)));
    }

    /// <summary>
    /// Fits FFX models and returns the Pareto-optimal ones (fewest basis functions vs. lowest MSE),
    /// ordered by ascending number of basis functions.
    /// </summary>
    /// <param name="data">Regression problem data to learn from.</param>
    /// <param name="elnetPenalty">Elastic-net penalty (alpha) forwarded to every approach.</param>
    /// <param name="numBases">Receives the number of basis functions of each returned model, in the same order.</param>
    /// <param name="exp">Allow exponentiation basis functions.</param>
    /// <param name="nonlinFuncs">Allow nonlinear-operator basis functions.</param>
    /// <param name="interactions">Allow interaction basis functions.</param>
    /// <param name="denoms">Allow denominator basis functions.</param>
    /// <param name="hingeFuncs">Allow hinge basis functions.</param>
    /// <param name="maxNumBases">Maximum number of basis functions per model; negative means no restriction.</param>
    /// <returns>A materialized sequence of models (enumerating it repeatedly yields the same instances).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1) {
      if (data == null) throw new ArgumentNullException(nameof(data));

      var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty);

      var allFFXModels = approaches
        .SelectMany(approach => CreateFFXModels(data, approach)).ToList();

      // Final pareto filter over all generated models from all different approaches
      var nondominatedFFXModels = NondominatedModels(data, allFFXModels).ToArray();

      numBases = nondominatedFFXModels
        .Select(ffxModel => ffxModel.NumBases).ToArray();
      // Materialize so callers that enumerate more than once do not rebuild new model objects each time.
      return nondominatedFFXModels
        .Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable))
        .ToArray();
    }

    /// <summary>
    /// Returns the models on the Pareto front of (number of basis functions, MSE), both minimized,
    /// ordered by ascending number of basis functions. The MSE is measured on the test partition,
    /// or on the training partition if the problem has no test rows.
    /// </summary>
    /// <exception cref="InvalidOperationException">The MSE of a model could not be computed.</exception>
    private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) {
      var models = ffxModels.ToArray();

      // Fall back to training rows when there is no test partition; otherwise the MSE calculator
      // reports an error for every model and nothing could be returned.
      var useTestRows = data.TestIndices.Any();
      var rows = useTestRows ? data.TestIndices : data.TrainingIndices;
      // loop-invariant: read the target values once instead of once per model
      var originalValues = (useTestRows ? data.TargetVariableTestValues : data.TargetVariableTrainingValues).ToArray();

      var qualities = new double[models.Length][];
      for (int i = 0; i < models.Length; i++) {
        var estimatedValues = models[i].Simulate(data, rows);
        // do not create a regressionSolution here for better performance:
        // RegressionSolutions calculate all kinds of errors when calling the ctor, but we only need the MSE
        var mse = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state);
        if (state != OnlineCalculatorError.None)
          throw new InvalidOperationException("could not calculate the MSE of an FFX model (calculator state: " + state + ")");
        qualities[i] = new double[] { models[i].NumBases, mse };
      }

      return DominationCalculator<FFXModel>.CalculateBestParetoFront(models, qualities, new bool[] { false, false })
        .Select(tuple => tuple.Item1)
        .OrderBy(ffxModel => ffxModel.NumBases)
        .ToArray();
    }

    /// <summary>
    /// Builds FFX models for a single approach: creates the basis functions (FFX step 1), fits an
    /// elastic-net path over them (FFX step 2), and re-optimizes the coefficients of each distinct model.
    /// </summary>
    private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) {
      // FFX Step 1
      var funcsArr = BFUtils.CreateBasisFunctions(data, approach).ToArray();
      var elnetData = BFUtils.PrepareData(data, funcsArr);

      // FFX Step 2
      // NOTE(review): a negative MaxNumBases is forwarded as maxVars; this assumes the elastic net treats
      // a negative maxVars as "no limit" — confirm against ElasticNetLinearRegression.
      ElasticNetLinearRegression.RunElasticNetLinearRegression(elnetData, approach.ElnetPenalty, out var _, out var _, out var _, out var candidateCoeffs, out var intercept, maxVars: approach.MaxNumBases);

      // create models out of the learned coefficients
      var ffxModels = GetUniqueModelsFromCoeffs(candidateCoeffs, intercept, funcsArr, approach);

      // one last LS-optimization step on the training data (the intercept-only model has nothing to optimize)
      foreach (var ffxModel in ffxModels) {
        if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data);
      }

      return ffxModels;
    }

    /// <summary>
    /// Creates one model per lambda step whose set of non-zero basis functions has not been seen
    /// before and does not exceed <c>approach.MaxNumBases</c> (a negative limit means unlimited).
    /// </summary>
    private static IEnumerable<FFXModel> GetUniqueModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) {
      var ffxModels = new List<FFXModel>();
      var seenCombinations = new HashSet<string>();
      // negative MaxNumBases => no restriction on size (previously this discarded every model)
      var limitSize = approach.MaxNumBases >= 0;

      for (int i = 0; i < intercept.Length; i++) {
        var row = candidateCoeffs.GetRow(i);
        var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray();
        if (limitSize && nonzeroIndices.Length > approach.MaxNumBases) continue;
        // ignore duplicate models (models with same combination of basis functions)
        if (!seenCombinations.Add(string.Join(",", nonzeroIndices))) continue;
        ffxModels.Add(new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx]))));
      }
      return ffxModels;
    }

    /// <summary>
    /// Enumerates every subset of the enabled feature kinds
    /// (interactions, denominator, exponentiation, nonlinear functions, hinge functions)
    /// that contains at most three kinds and does not combine interactions with exponentiation.
    /// </summary>
    private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) {
      var approaches = new List<Approach>();
      var valids = new bool[] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions };
      int numKinds = valids.Length;

      // return true if ALL indices of true values of arr1 also have true values in arr2
      bool follows(bool[] arr1, bool[] arr2) {
        if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths");
        for (int i = 0; i < arr1.Length; i++) {
          if (arr1[i] && !arr2[i]) return false;
        }
        return true;
      }

      for (int i = 0; i < (1 << numKinds); i++) {
        // map i to a bool array with one flag per feature kind
        var arr = i.ToBoolArray(numKinds);
        if (!follows(arr, valids)) continue;
        if (arr.Count(b => b) >= 4) continue; // not too many features at once
        if (arr[0] && arr[2]) continue; // never need both exponent and inter
        approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4], exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs));
      }
      return approaches;
    }
  }
}
---|