- Timestamp:
- 08/29/20 09:23:06 (4 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs
r17227 r17737 2 2 using System.Threading; 3 3 using System.Linq; 4 using HeuristicLab.Common; // required for parameters collection 5 using HeuristicLab.Core; // required for parameters collection 6 using HeuristicLab.Data; // IntValue, ... 7 using HeuristicLab.Encodings.BinaryVectorEncoding; 8 using HeuristicLab.Optimization; // BasicAlgorithm 4 using HeuristicLab.Common; 5 using HeuristicLab.Core; 6 using HeuristicLab.Data; 7 using HeuristicLab.Optimization; 9 8 using HeuristicLab.Parameters; 10 using HeuristicLab.Problems.Binary;11 using HeuristicLab.Random; // MersenneTwister12 9 using HEAL.Attic; 13 10 using HeuristicLab.Algorithms.DataAnalysis.Glmnet; 14 11 using HeuristicLab.Problems.DataAnalysis; 15 12 using System.Collections.Generic; 16 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;17 using System.Collections;18 using System.Diagnostics;19 using HeuristicLab.Problems.DataAnalysis.Symbolic;20 13 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression; 21 using HeuristicLab.Analysis;22 using HeuristicLab.Collections;23 14 24 15 namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction { 25 16 26 [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")] 27 [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)] 28 [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive 29 public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> { 30 31 private static readonly double[] exponents = { 0.5, 1, 2 }; 32 private static readonly OpCode[] nonlinFuncs = { OpCode.Absolute, OpCode.Log, OpCode.Sin, OpCode.Cos }; 33 34 private static readonly BidirectionalDictionary<OpCode, string> OpCodeToString = new BidirectionalDictionary<OpCode, string> { 35 { OpCode.Log, "LOG" }, 36 { OpCode.Absolute, "ABS"}, 37 { OpCode.Sin, "SIN"}, 38 { OpCode.Cos, "COS"}, 39 { OpCode.Square, "SQR"}, 40 { OpCode.SquareRoot, "SQRT"}, 41 { OpCode.Cube, "CUBE"}, 42 { OpCode.CubeRoot, "CUBEROOT"} 43 }; 44 45 private const string ConsiderInteractionsParameterName = "Consider Interactions"; 46 private const string ConsiderDenominationParameterName = "Consider Denomination"; 47 private const string ConsiderExponentiationParameterName = "Consider Exponentiation"; 48 private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions"; 49 private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions"; 50 private const string PenaltyParameterName = "Penalty"; 51 private const string LambdaParameterName = "Lambda"; 52 private const string NonlinearFuncsParameterName = "Nonlinear Functions"; 53 54 #region parameters 55 public IValueParameter<BoolValue> ConsiderInteractionsParameter 56 { 57 get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; } 17 [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")] 18 [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)] 19 [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive 20 public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> { 21 22 #region constants 23 private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 }; 24 private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None }; 25 private static readonly double minHingeThr = 0.2; 26 private static readonly double maxHingeThr = 0.8; 27 private static readonly int numHingeThrs = 5; 28 29 private const string ConsiderInteractionsParameterName = "Consider Interactions"; 30 private const string ConsiderDenominationParameterName = "Consider Denomination"; 31 private const string ConsiderExponentiationParameterName = "Consider Exponentiation"; 32 private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions"; 33 private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions"; 34 private const string LambdaParameterName = "Elastic Net Lambda"; 35 private const string PenaltyParameterName = "Elastic Net Penalty"; 36 private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions"; 37 38 #endregion 39 40 #region parameters 41 public IValueParameter<BoolValue> ConsiderInteractionsParameter { 42 get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; } 43 } 44 public IValueParameter<BoolValue> ConsiderDenominationsParameter { 45 get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; } 46 } 47 public IValueParameter<BoolValue> ConsiderExponentiationsParameter { 48 get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; } 49 } 50 public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter { 51 get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; } 52 } 53 public IValueParameter<BoolValue> ConsiderHingeFuncsParameter { 54 get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; } 55 } 56 public IValueParameter<DoubleValue> PenaltyParameter { 57 get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; } 58 } 59 public IValueParameter<DoubleValue> LambdaParameter { 60 get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; } 61 } 62 public IValueParameter<IntValue> MaxNumBasisFuncsParameter { 63 get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; } 64 } 65 #endregion 66 67 #region properties 68 public bool ConsiderInteractions { 69 get { return ConsiderInteractionsParameter.Value.Value; } 70 set { ConsiderInteractionsParameter.Value.Value = value; } 71 } 72 public bool ConsiderDenominations { 73 get { return ConsiderDenominationsParameter.Value.Value; } 74 set { ConsiderDenominationsParameter.Value.Value = value; } 75 } 76 public bool ConsiderExponentiations { 77 get { return ConsiderExponentiationsParameter.Value.Value; } 78 set { ConsiderExponentiationsParameter.Value.Value = value; } 79 } 80 public bool ConsiderNonlinearFunctions { 81 get { return ConsiderNonlinearFuncsParameter.Value.Value; } 82 set { ConsiderNonlinearFuncsParameter.Value.Value = value; } 83 } 84 public bool ConsiderHingeFunctions { 85 get { return ConsiderHingeFuncsParameter.Value.Value; } 86 set { ConsiderHingeFuncsParameter.Value.Value = value; } 87 } 88 public double Penalty { 89 get { return PenaltyParameter.Value.Value; } 90 set { PenaltyParameter.Value.Value = value; } 91 } 92 public DoubleValue Lambda { 93 get { return LambdaParameter.Value; } 94 set { LambdaParameter.Value = value; } 95 } 96 public int MaxNumBasisFuncs { 97 get { return MaxNumBasisFuncsParameter.Value.Value; } 98 set { MaxNumBasisFuncsParameter.Value.Value = value; } 99 } 100 #endregion 101 102 #region ctor 103 104 [StorableConstructor] 105 private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { } 106 public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) { 107 } 108 public FastFunctionExtraction() : base() { 109 base.Problem = new RegressionProblem(); 110 Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true))); 111 Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true))); 112 Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true))); 113 Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true))); 114 Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true))); 115 Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Set how many basis functions the models can have at most. if Max Num Basis Funcs is negative => no restriction on size", new IntValue(20))); 116 Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas")); 117 Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.05))); 118 } 119 120 [StorableHook(HookType.AfterDeserialization)] 121 private void AfterDeserialization() { } 122 123 public override IDeepCloneable Clone(Cloner cloner) { 124 return new FastFunctionExtraction(this, cloner); 125 } 126 #endregion 127 128 public override Type ProblemType { get { return typeof(RegressionProblem); } } 129 public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } } 130 131 protected override void Run(CancellationToken cancellationToken) { 132 var models = Fit(Problem.ProblemData, Penalty, out var numBases, ConsiderExponentiations, 133 ConsiderNonlinearFunctions, ConsiderInteractions, 134 ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs 135 ); 136 137 int i = 0; 138 var numBasesArr = numBases.ToArray(); 139 var solutionsArr = new List<ISymbolicRegressionSolution>(); 140 141 foreach (var model in models) { 142 Results.Add(new Result( 143 "Num Bases: " + numBasesArr[i++], 144 model 145 )); 146 solutionsArr.Add(new SymbolicRegressionSolution(model, Problem.ProblemData)); 147 } 148 Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutionsArr))); 149 } 150 151 public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1) { 152 var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty); 153 154 var allFFXModels = approaches 155 .SelectMany(approach => CreateFFXModels(data, approach)).ToList(); 156 157 // Final pareto filter over all generated models from all different approaches 158 var nondominatedFFXModels = NondominatedModels(data, allFFXModels); 159 160 numBases = nondominatedFFXModels 161 .Select(ffxModel => ffxModel.NumBases).ToArray(); 162 return nondominatedFFXModels.Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable)); 163 } 164 165 private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) { 166 var numBases = ffxModels.Select(ffxModel => (double)ffxModel.NumBases).ToArray(); 167 var errors = ffxModels.Select(ffxModel => { 168 var originalValues = data.TargetVariableTestValues.ToArray(); 169 var estimatedValues = ffxModel.Simulate(data, data.TestIndices); 170 // do not create a regressionSolution here for better performance: 171 // RegressionSolutions calculate all kinds of errors when calling the ctor, but we only need testMSE 172 var testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state); 173 if (state != OnlineCalculatorError.None) throw new ArrayTypeMismatchException("could not calculate TestMSE"); 174 return testMSE; 175 }).ToArray(); 176 177 int n = numBases.Length; 178 double[][] qualities = new double[n][]; 179 for (int i = 0; i < n; i++) { 180 qualities[i] = new double[2]; 181 qualities[i][0] = numBases[i]; 182 qualities[i][1] = errors[i]; 183 } 184 185 return DominationCalculator<FFXModel>.CalculateBestParetoFront(ffxModels.ToArray(), qualities, new bool[] { false, false }) 186 .Select(tuple => tuple.Item1).OrderBy(ffxModel => ffxModel.NumBases); 187 } 188 189 // Build FFX models 190 private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) { 191 // FFX Step 1 192 var basisFunctions = BFUtils.CreateBasisFunctions(data, approach).ToArray(); 193 194 var funcsArr = basisFunctions.ToArray(); 195 var elnetData = BFUtils.PrepareData(data, funcsArr); 196 197 // FFX Step 2 198 ElasticNetLinearRegression.RunElasticNetLinearRegression(elnetData, approach.ElnetPenalty, out var _, out var _, out var _, out var candidateCoeffs, out var intercept, maxVars: approach.MaxNumBases); 199 200 // create models out of the learned coefficients 201 var ffxModels = GetUniqueModelsFromCoeffs(candidateCoeffs, intercept, funcsArr, approach); 202 203 // one last LS-optimization step on the training data 204 foreach (var ffxModel in ffxModels) { 205 //if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data); 206 } 207 return ffxModels; 208 } 209 210 // finds all models with unique combinations of basis functions 211 private static IEnumerable<FFXModel> GetUniqueModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) { 212 List<FFXModel> ffxModels = new List<FFXModel>(); 213 List<int[]> bfCombinations = new List<int[]>(); 214 215 for (int i = 0; i < intercept.Length; i++) { 216 var row = candidateCoeffs.GetRow(i); 217 var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray(); 218 if (nonzeroIndices.Count() > approach.MaxNumBases) continue; 219 // ignore duplicate models (models with same combination of basis functions) 220 if (bfCombinations.Any(arr => Enumerable.SequenceEqual(arr, nonzeroIndices))) continue; 221 var ffxModel = new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx]))); 222 bfCombinations.Add(nonzeroIndices); 223 ffxModels.Add(ffxModel); 224 } 225 return ffxModels; 226 } 227 228 private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) { 229 var approaches = new List<Approach>(); 230 var valids = new bool[5] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions }; 231 232 // return true if ALL indices of true values of arr1 also have true values in arr2 233 bool follows(bool[] arr1, bool[] arr2) { 234 if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths"); 235 for (int i = 0; i < arr1.Length; i++) { 236 if (arr1[i] && !arr2[i]) return false; 237 } 238 return true; 239 } 240 241 for (int i = 0; i < 32; i++) { 242 // map i to a bool array of length 5 243 var arr = i.ToBoolArray(5); 244 if (!follows(arr, valids)) continue; 245 int sum = arr.Where(b => b).Count(); // how many features are enabled? 246 if (sum >= 4) continue; // not too many features at once 247 if (arr[0] && arr[2]) continue; // never need both exponent and inter 248 approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4], exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs)); 249 } 250 return approaches; 251 } 58 252 } 59 public IValueParameter<BoolValue> ConsiderDenominationsParameter60 {61 get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }62 }63 public IValueParameter<BoolValue> ConsiderExponentiationsParameter64 {65 get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }66 }67 public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter68 {69 get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }70 }71 public IValueParameter<BoolValue> ConsiderHingeFuncsParameter72 {73 get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }74 }75 public IValueParameter<DoubleValue> PenaltyParameter76 {77 get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }78 }79 public IValueParameter<DoubleValue> LambdaParameter80 {81 get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }82 }83 public IValueParameter<CheckedItemCollection<EnumValue<OpCode>>> NonlinearFuncsParameter84 {85 get { return (IValueParameter<CheckedItemCollection<EnumValue<OpCode>>>)Parameters[NonlinearFuncsParameterName]; }86 }87 #endregion88 89 #region properties90 public bool ConsiderInteractions91 {92 get { return ConsiderInteractionsParameter.Value.Value; }93 set { ConsiderInteractionsParameter.Value.Value = value; }94 }95 public bool ConsiderDenominations96 {97 get { return ConsiderDenominationsParameter.Value.Value; }98 set { ConsiderDenominationsParameter.Value.Value = value; }99 }100 public bool ConsiderExponentiations101 {102 get { return ConsiderExponentiationsParameter.Value.Value; }103 set { ConsiderExponentiationsParameter.Value.Value = value; }104 }105 public bool ConsiderNonlinearFuncs106 {107 get { return ConsiderNonlinearFuncsParameter.Value.Value; }108 set { ConsiderNonlinearFuncsParameter.Value.Value = value; }109 }110 public bool ConsiderHingeFuncs111 {112 get { return ConsiderHingeFuncsParameter.Value.Value; }113 set { ConsiderHingeFuncsParameter.Value.Value = value; }114 }115 public double Penalty116 {117 get { return PenaltyParameter.Value.Value; }118 set { PenaltyParameter.Value.Value = value; }119 }120 public DoubleValue Lambda121 {122 get { return LambdaParameter.Value; }123 set { LambdaParameter.Value = value; }124 }125 public CheckedItemCollection<EnumValue<OpCode>> NonlinearFuncs126 {127 get { return NonlinearFuncsParameter.Value; }128 set { NonlinearFuncsParameter.Value = value; }129 }130 #endregion131 132 133 [StorableConstructor]134 private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }135 public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {136 }137 public FastFunctionExtraction() : base() {138 var items = new CheckedItemCollection<EnumValue<OpCode>>();139 foreach (var op in nonlinFuncs) {140 items.Add(new EnumValue<OpCode>(op));141 }142 base.Problem = new RegressionProblem();143 Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));144 Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));145 Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));146 Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));147 Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));148 Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.9)));149 Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));150 Parameters.Add(new ValueParameter<CheckedItemCollection<EnumValue<OpCode>>>(NonlinearFuncsParameterName, "What nonlinear functions the models should be able to include.", items));151 }152 153 [StorableHook(HookType.AfterDeserialization)]154 private void AfterDeserialization() { }155 156 public override IDeepCloneable Clone(Cloner cloner) {157 return new FastFunctionExtraction(this, cloner);158 }159 160 public override Type ProblemType { get { return typeof(RegressionProblem); } }161 public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }162 163 public override bool SupportsPause { get { return true; } }164 165 protected override void Run(CancellationToken cancellationToken) {166 var basisFunctions = createBasisFunctions(Problem.ProblemData);167 Results.Add(new Result("Basis Functions", "A Dataset consisting of the generated Basis Functions from FFX Alg Step 1.", createProblemData(Problem.ProblemData, basisFunctions)));168 169 // add denominator bases to the already existing basis functions170 if (ConsiderDenominations) basisFunctions = basisFunctions.Concat(createDenominatorBases(Problem.ProblemData, basisFunctions)).ToList();171 172 // create either path of solutions, or one solution for given lambda173 LearnModels(Problem.ProblemData, basisFunctions);174 }175 176 private List<BasisFunction> createBasisFunctions(IRegressionProblemData problemData) {177 var basisFunctions = createUnivariateBases(problemData);178 basisFunctions = basisFunctions.Concat(createMultivariateBases(basisFunctions)).ToList();179 return basisFunctions;180 }181 182 private List<BasisFunction> createUnivariateBases(IRegressionProblemData problemData) {183 var B1 = new List<BasisFunction>();184 var inputVariables = problemData.AllowedInputVariables;185 var validExponents = ConsiderExponentiations ? exponents : new double[] { 1 };186 var validFuncs = NonlinearFuncs.CheckedItems.Select(val => val.Value);187 // TODO: add Hinge functions188 189 foreach (var variableName in inputVariables) {190 foreach (var exp in validExponents) {191 var data = problemData.Dataset.GetDoubleValues(variableName).Select(x => Math.Pow(x, exp)).ToArray();192 if (!ok(data)) continue;193 var name = expToString(exp, variableName);194 B1.Add(new BasisFunction(name, data, false));195 foreach (OpCode _op in validFuncs) {196 var inner_data = data.Select(x => eval(_op, x)).ToArray();197 if (!ok(inner_data)) continue;198 var inner_name = OpCodeToString.GetByFirst(_op) + "(" + name + ")";199 B1.Add(new BasisFunction(inner_name, inner_data, true));200 }201 }202 }203 return B1;204 }205 206 private List<BasisFunction> createMultivariateBases(List<BasisFunction> B1) {207 if (!ConsiderInteractions) return B1;208 var B2 = new List<BasisFunction>();209 for (int i = 0; i < B1.Count(); i++) {210 var b_i = B1.ElementAt(i);211 for (int j = 0; j < i; j++) {212 var b_j = B1.ElementAt(j);213 if (b_j.IsOperator) continue; // disallow op() * op()214 var b_inter = b_i * b_j;215 B2.Add(b_inter);216 }217 }218 219 return B2;220 // return union of B1 and B2221 }222 223 // creates 1 denominator basis function for each corresponding basis function from basisFunctions224 private IEnumerable<BasisFunction> createDenominatorBases(IRegressionProblemData problemData, IEnumerable<BasisFunction> basisFunctions) {225 var y = new BasisFunction(problemData.TargetVariable, problemData.TargetVariableValues.ToArray(), false);226 var denomBasisFuncs = new List<BasisFunction>();227 foreach (var func in basisFunctions) {228 var denomFunc = y * func;229 denomBasisFuncs.Add(denomFunc);230 }231 return denomBasisFuncs;232 }233 234 private static string expToString(double exponent, string varname) {235 if (exponent.IsAlmost(1)) return varname;236 if (exponent.IsAlmost(1 / 2)) return OpCodeToString.GetByFirst(OpCode.SquareRoot) + "(" + varname + ")";237 if (exponent.IsAlmost(1 / 3)) return OpCodeToString.GetByFirst(OpCode.CubeRoot) + "(" + varname + ")";238 if (exponent.IsAlmost(2)) return OpCodeToString.GetByFirst(OpCode.Square) + "(" + varname + ")";239 if (exponent.IsAlmost(3)) return OpCodeToString.GetByFirst(OpCode.Cube) + "(" + varname + ")";240 else return varname + " ^ " + exponent;241 }242 243 public static double eval(OpCode op, double x) {244 switch (op) {245 case OpCode.Absolute:246 return Math.Abs(x);247 case OpCode.Log:248 return Math.Log10(x);249 case OpCode.Sin:250 return Math.Sin(x);251 case OpCode.Cos:252 return Math.Cos(x);253 default:254 throw new Exception("Unimplemented operator: " + op.ToString());255 }256 }257 258 private void PathwiseLearning(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {259 ElasticNetLinearRegression reg = new ElasticNetLinearRegression();260 reg.Lambda = Lambda;261 reg.Penality = Penalty;262 reg.Problem.ProblemData = createProblemData(problemData, basisFunctions);263 reg.Start();264 Results.AddRange(reg.Results);265 }266 267 private void LearnModels(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {268 double[] lambda;269 double[] trainNMSE;270 double[] testNMSE;271 double[,] coeff;272 double[] intercept;273 int numNominatorBases = ConsiderDenominations ? basisFunctions.Count / 2 : basisFunctions.Count;274 275 // wraps the list of basis functions in a dataset, so that it can be passed on to the ElNet function276 var X_b = createProblemData(problemData, basisFunctions);277 278 ElasticNetLinearRegression.RunElasticNetLinearRegression(X_b, Penalty, out lambda, out trainNMSE, out testNMSE, out coeff, out intercept);279 280 var errorTable = NMSEGraph(coeff, lambda, trainNMSE, testNMSE);281 Results.Add(new Result(errorTable.Name, errorTable.Description, errorTable));282 var coeffTable = CoefficientGraph(coeff, lambda, X_b.AllowedInputVariables, X_b.Dataset);283 Results.Add(new Result(coeffTable.Name, coeffTable.Description, coeffTable));284 285 ItemCollection<IResult> models = new ItemCollection<IResult>();286 for (int modelIdx = 0; modelIdx < coeff.GetUpperBound(0); modelIdx++) {287 var tree = Tree(basisFunctions, GetRow(coeff, modelIdx), intercept[modelIdx]);288 ISymbolicRegressionModel m = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());289 ISymbolicRegressionSolution s = new SymbolicRegressionSolution(m, Problem.ProblemData);290 models.Add(new Result("Solution " + modelIdx, s));291 }292 293 Results.Add(new Result("Models", "The model path returned by the Elastic Net Regression (not only the pareto-optimal subset). ", models));294 }295 296 private static IndexedDataTable<double> CoefficientGraph(double[,] coeff, double[] lambda, IEnumerable<string> allowedVars, IDataset ds) {297 var coeffTable = new IndexedDataTable<double>("Coefficients", "The paths of standarized coefficient values over different lambda values");298 coeffTable.VisualProperties.YAxisMaximumAuto = false;299 coeffTable.VisualProperties.YAxisMinimumAuto = false;300 coeffTable.VisualProperties.XAxisMaximumAuto = false;301 coeffTable.VisualProperties.XAxisMinimumAuto = false;302 303 coeffTable.VisualProperties.XAxisLogScale = true;304 coeffTable.VisualProperties.XAxisTitle = "Lambda";305 coeffTable.VisualProperties.YAxisTitle = "Coefficients";306 coeffTable.VisualProperties.SecondYAxisTitle = "Number of variables";307 308 var nLambdas = lambda.Length;309 var nCoeff = coeff.GetLength(1);310 var dataRows = new IndexedDataRow<double>[nCoeff];311 var numNonZeroCoeffs = new int[nLambdas];312 313 var doubleVariables = allowedVars.Where(ds.VariableHasType<double>);314 var factorVariableNames = allowedVars.Where(ds.VariableHasType<string>);315 var factorVariablesAndValues = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)316 {317 int i = 0;318 foreach (var factorVariableAndValues in factorVariablesAndValues) {319 foreach (var factorValue in factorVariableAndValues.Value) {320 double sigma = ds.GetStringValues(factorVariableAndValues.Key)321 .Select(s => s == factorValue ? 1.0 : 0.0)322 .StandardDeviation(); // calc std dev of binary indicator323 var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();324 dataRows[i] = new IndexedDataRow<double>(factorVariableAndValues.Key + "=" + factorValue, factorVariableAndValues.Key + "=" + factorValue, path);325 i++;326 }327 }328 329 foreach (var doubleVariable in doubleVariables) {330 double sigma = ds.GetDoubleValues(doubleVariable).StandardDeviation();331 var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();332 dataRows[i] = new IndexedDataRow<double>(doubleVariable, doubleVariable, path);333 i++;334 }335 // add to coeffTable by total weight (larger area under the curve => more important);336 foreach (var r in dataRows.OrderByDescending(r => r.Values.Select(t => t.Item2).Sum(x => Math.Abs(x)))) {337 coeffTable.Rows.Add(r);338 }339 }340 341 for (int i = 0; i < coeff.GetLength(0); i++) {342 for (int j = 0; j < coeff.GetLength(1); j++) {343 if (!coeff[i, j].IsAlmost(0.0)) {344 numNonZeroCoeffs[i]++;345 }346 }347 }348 if (lambda.Length > 2) {349 coeffTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));350 coeffTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));351 }352 coeffTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));353 coeffTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;354 coeffTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;355 356 return coeffTable;357 }358 359 private static IndexedDataTable<double> NMSEGraph(double[,] coeff, double[] lambda, double[] trainNMSE, double[] testNMSE) {360 var errorTable = new IndexedDataTable<double>("NMSE", "Path of NMSE values over different lambda values");361 var numNonZeroCoeffs = new int[lambda.Length];362 errorTable.VisualProperties.YAxisMaximumAuto = false;363 errorTable.VisualProperties.YAxisMinimumAuto = false;364 errorTable.VisualProperties.XAxisMaximumAuto = false;365 errorTable.VisualProperties.XAxisMinimumAuto = false;366 367 for (int i = 0; i < coeff.GetLength(0); i++) {368 for (int j = 0; j < coeff.GetLength(1); j++) {369 if (!coeff[i, j].IsAlmost(0.0)) {370 numNonZeroCoeffs[i]++;371 }372 }373 }374 375 errorTable.VisualProperties.YAxisMinimumFixedValue = 0;376 errorTable.VisualProperties.YAxisMaximumFixedValue = 1.0;377 errorTable.VisualProperties.XAxisLogScale = true;378 errorTable.VisualProperties.XAxisTitle = "Lambda";379 errorTable.VisualProperties.YAxisTitle = "Normalized mean of squared errors (NMSE)";380 errorTable.VisualProperties.SecondYAxisTitle = "Number of variables";381 errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (train)", "Path of NMSE values over different lambda values", lambda.Zip(trainNMSE, (l, v) => Tuple.Create(l, v))));382 errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (test)", "Path of NMSE values over different lambda values", lambda.Zip(testNMSE, (l, v) => Tuple.Create(l, v))));383 errorTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));384 if (lambda.Length > 2) {385 errorTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));386 errorTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));387 }388 errorTable.Rows["NMSE (train)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;389 errorTable.Rows["NMSE (test)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;390 errorTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;391 errorTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;392 393 return errorTable;394 }395 396 private ISymbolicExpressionTree Tree(List<BasisFunction> basisFunctions, double[] coeffs, double offset) {397 Debug.Assert(basisFunctions.Count() == coeffs.Length);398 //SymbolicExpressionTree399 var numNumeratorFuncs = ConsiderDenominations ? basisFunctions.Count() / 2 : basisFunctions.Count();400 var numeratorBasisFuncs = basisFunctions.Take(numNumeratorFuncs);401 402 // returns true if there exists at least 1 coefficient value in the model that is part of the denominator403 // (i.e. if there exists at least 1 non-zero value in the second half of the array)404 bool withDenom(double[] coeffarr) => coeffarr.Take(coeffarr.Length / 2).ToArray().Any(val => !val.IsAlmost(0.0));405 string model = "(" + offset.ToString();406 for (int i = 0; i < numNumeratorFuncs; i++) {407 var func = basisFunctions.ElementAt(i);408 // only generate nodes for relevant basis functions (those with non-zero coeffs)409 if (!coeffs[i].IsAlmost(0.0))410 model += " + (" + coeffs[i] + ") * " + func.Var;411 }412 if (ConsiderDenominations && withDenom(coeffs)) {413 model += ") / (1";414 for (int i = numNumeratorFuncs; i < basisFunctions.Count(); i++) {415 var func = basisFunctions.ElementAt(i);416 // only generate nodes for relevant basis functions (those with non-zero coeffs)417 if (!coeffs[i].IsAlmost(0.0))418 model += " + (" + coeffs[i] + ") * " + func.Var.Substring(4);419 }420 }421 model += ")";422 InfixExpressionParser p = new InfixExpressionParser();423 return p.Parse(model);424 }425 426 // wraps the list of basis functions into an IRegressionProblemData object427 private static IRegressionProblemData createProblemData(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {428 List<string> variableNames = new List<string>();429 List<IList> variableVals = new List<IList>();430 foreach (var basisFunc in basisFunctions) {431 variableNames.Add(basisFunc.Var);432 // basisFunctions already contains the calculated values of the corresponding basis function, so you can just take that value433 variableVals.Add(new List<double>(basisFunc.Val));434 }435 var matrix = new ModifiableDataset(variableNames, variableVals);436 437 // add the unmodified target variable to the matrix438 matrix.AddVariable(problemData.TargetVariable, problemData.TargetVariableValues.ToList());439 var allowedInputVars = matrix.VariableNames.Where(x => !x.Equals(problemData.TargetVariable));440 IRegressionProblemData rpd = new RegressionProblemData(matrix, allowedInputVars, problemData.TargetVariable);441 rpd.TrainingPartition.Start = problemData.TrainingPartition.Start;442 rpd.TrainingPartition.End = problemData.TrainingPartition.End;443 rpd.TestPartition.Start = problemData.TestPartition.Start;444 rpd.TestPartition.End = problemData.TestPartition.End;445 return rpd;446 }447 448 private static bool ok(double[] data) => data.All(x => !double.IsNaN(x) && !double.IsInfinity(x));449 450 // helper function which returns a row of a 2D array451 private static T[] GetRow<T>(T[,] matrix, int row) {452 var columns = matrix.GetLength(1);453 var array = new T[columns];454 for (int i = 0; i < columns; ++i)455 array[i] = matrix[row, i];456 return array;457 }458 459 // returns all models with pareto-optimal tradeoff between error and complexity460 private static List<IRegressionSolution> nondominatedFilter(double[][] coefficientVectorSet, BasisFunction[] basisFunctions) {461 return null;462 }463 }464 253 }
Note: See TracChangeset
for help on using the changeset viewer.