using System;
using System.Threading;
using System.Linq;
using HeuristicLab.Common; // required for parameters collection
using HeuristicLab.Core; // required for parameters collection
using HeuristicLab.Data; // IntValue, ...
using HeuristicLab.Encodings.BinaryVectorEncoding;
using HeuristicLab.Optimization; // BasicAlgorithm
using HeuristicLab.Parameters;
using HeuristicLab.Problems.Binary;
using HeuristicLab.Random; // MersenneTwister
using HEAL.Attic;
using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
using HeuristicLab.Problems.DataAnalysis;
using System.Collections.Generic;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using System.Collections;
using System.Diagnostics;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
using HeuristicLab.Analysis;
using HeuristicLab.Collections;

namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

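  // Fast Function Extraction (FFX) as implemented here works in two steps:
  //   1. deterministically generate a large set of basis functions from the input variables
  //      (powers, nonlinear functions, pairwise interactions, and optionally denominator bases), and
  //   2. fit a pathwise elastic net over these bases, so that each point of the regularization path
  //      yields a sparse linear model that is then converted into a symbolic regression solution.
  // The outline follows the published FFX approach (McConaghy); hinge functions are not
  // implemented yet (see the TODO in createUnivariateBases).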
  [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing the algorithm to a file or transferring it to HeuristicLab.Hive)
  public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {

    private static readonly double[] exponents = { 0.5, 1, 2 };
    private static readonly OpCode[] nonlinFuncs = { OpCode.Absolute, OpCode.Log, OpCode.Sin, OpCode.Cos };

    private static readonly BidirectionalDictionary<OpCode, string> OpCodeToString = new BidirectionalDictionary<OpCode, string> {
      { OpCode.Log, "LOG" },
      { OpCode.Absolute, "ABS" },
      { OpCode.Sin, "SIN" },
      { OpCode.Cos, "COS" },
      { OpCode.Square, "SQR" },
      { OpCode.SquareRoot, "SQRT" },
      { OpCode.Cube, "CUBE" },
      { OpCode.CubeRoot, "CUBEROOT" }
    };

    private const string ConsiderInteractionsParameterName = "Consider Interactions";
    private const string ConsiderDenominationParameterName = "Consider Denomination";
    private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
    private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    private const string PenaltyParameterName = "Penalty";
    private const string LambdaParameterName = "Lambda";
    private const string NonlinearFuncsParameterName = "Nonlinear Functions";

    #region parameters
    public IValueParameter<BoolValue> ConsiderInteractionsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderDenominationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
    }
    public IValueParameter<DoubleValue> PenaltyParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
    }
    public IValueParameter<DoubleValue> LambdaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    }
    public IValueParameter<CheckedItemCollection<EnumValue<OpCode>>> NonlinearFuncsParameter {
      get { return (IValueParameter<CheckedItemCollection<EnumValue<OpCode>>>)Parameters[NonlinearFuncsParameterName]; }
    }
    #endregion

    #region properties
    public bool ConsiderInteractions {
      get { return ConsiderInteractionsParameter.Value.Value; }
      set { ConsiderInteractionsParameter.Value.Value = value; }
    }
    public bool ConsiderDenominations {
      get { return ConsiderDenominationsParameter.Value.Value; }
      set { ConsiderDenominationsParameter.Value.Value = value; }
    }
    public bool ConsiderExponentiations {
      get { return ConsiderExponentiationsParameter.Value.Value; }
      set { ConsiderExponentiationsParameter.Value.Value = value; }
    }
    public bool ConsiderNonlinearFuncs {
      get { return ConsiderNonlinearFuncsParameter.Value.Value; }
      set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
    }
    public bool ConsiderHingeFuncs {
      get { return ConsiderHingeFuncsParameter.Value.Value; }
      set { ConsiderHingeFuncsParameter.Value.Value = value; }
    }
    public double Penalty {
      get { return PenaltyParameter.Value.Value; }
      set { PenaltyParameter.Value.Value = value; }
    }
    public DoubleValue Lambda {
      get { return LambdaParameter.Value; }
      set { LambdaParameter.Value = value; }
    }
    public CheckedItemCollection<EnumValue<OpCode>> NonlinearFuncs {
      get { return NonlinearFuncsParameter.Value; }
      set { NonlinearFuncsParameter.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
    public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) { }
    public FastFunctionExtraction() : base() {
      var items = new CheckedItemCollection<EnumValue<OpCode>>();
      foreach (var op in nonlinFuncs) {
        items.Add(new EnumValue<OpCode>(op));
      }
      base.Problem = new RegressionProblem();
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if the models should include interaction terms, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if the models should include denominators, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if the models should include exponentiation, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if the models should include nonlinear functions (abs, log, ...), otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if the models should include hinge functions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression.", new DoubleValue(0.9)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas."));
      Parameters.Add(new ValueParameter<CheckedItemCollection<EnumValue<OpCode>>>(NonlinearFuncsParameterName, "The nonlinear functions the models are allowed to include.", items));
    }

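    // The algorithm is normally created and run via the GUI (see the Creatable attribute above).
    // From code, a minimal usage sketch could look like this (myRegressionProblemData stands for
    // an existing IRegressionProblemData instance and is not defined in this file):
    //   var ffx = new FastFunctionExtraction();
    //   ffx.Problem.ProblemData = myRegressionProblemData;
    //   ffx.ConsiderDenominations = false;   // toggle basis-function families as needed
    //   ffx.Start();                         // results appear in ffx.Results
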
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FastFunctionExtraction(this, cloner);
    }

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    public override bool SupportsPause { get { return true; } }

    protected override void Run(CancellationToken cancellationToken) {
      var basisFunctions = createBasisFunctions(Problem.ProblemData);
      Results.Add(new Result("Basis Functions", "A dataset consisting of the basis functions generated in step 1 of the FFX algorithm.", createProblemData(Problem.ProblemData, basisFunctions)));

      // add denominator bases to the already existing basis functions
      if (ConsiderDenominations) basisFunctions = basisFunctions.Concat(createDenominatorBases(Problem.ProblemData, basisFunctions)).ToList();

      // create either a path of solutions or a single solution for a given lambda
      LearnModels(Problem.ProblemData, basisFunctions);
    }

    private List<BasisFunction> createBasisFunctions(IRegressionProblemData problemData) {
      var basisFunctions = createUnivariateBases(problemData);
      basisFunctions = basisFunctions.Concat(createMultivariateBases(basisFunctions)).ToList();
      return basisFunctions;
    }


    private List<BasisFunction> createUnivariateBases(IRegressionProblemData problemData) {
      var B1 = new List<BasisFunction>();
      var inputVariables = problemData.AllowedInputVariables;
      var validExponents = ConsiderExponentiations ? exponents : new double[] { 1 };
      var validFuncs = NonlinearFuncs.CheckedItems.Select(val => val.Value);
      // TODO: add hinge functions

      foreach (var variableName in inputVariables) {
        foreach (var exp in validExponents) {
          var data = problemData.Dataset.GetDoubleValues(variableName).Select(x => Math.Pow(x, exp)).ToArray();
          if (!ok(data)) continue;
          var name = expToString(exp, variableName);
          B1.Add(new BasisFunction(name, data, false));
          foreach (OpCode _op in validFuncs) {
            var inner_data = data.Select(x => eval(_op, x)).ToArray();
            if (!ok(inner_data)) continue;
            var inner_name = OpCodeToString.GetByFirst(_op) + "(" + name + ")";
            B1.Add(new BasisFunction(inner_name, inner_data, true));
          }
        }
      }
      return B1;
    }

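    // Example (illustrative, assuming an input variable named "x1" and all options enabled):
    // the loops in createUnivariateBases produce bases such as x1, SQR(x1), SQRT(x1) and, wrapped
    // in the checked nonlinear functions, ABS(x1), LOG(SQR(x1)), SIN(SQRT(x1)), ... - each stored
    // together with its pre-evaluated data column.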

    // creates the pairwise interaction bases b_i * b_j; the caller builds the union of B1 and B2
    private List<BasisFunction> createMultivariateBases(List<BasisFunction> B1) {
      if (!ConsiderInteractions) return new List<BasisFunction>(); // nothing to add; the caller concatenates the result with B1
      var B2 = new List<BasisFunction>();
      for (int i = 0; i < B1.Count(); i++) {
        var b_i = B1.ElementAt(i);
        for (int j = 0; j < i; j++) {
          var b_j = B1.ElementAt(j);
          if (b_j.IsOperator) continue; // disallow op() * op()
          var b_inter = b_i * b_j;
          B2.Add(b_inter);
        }
      }

      return B2;
    }

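    // Note: B2 grows quadratically - roughly |B1|^2 / 2 candidate products before the op()*op()
    // filter; the elastic-net step later keeps only a sparse subset of these columns.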

    // creates 1 denominator basis function for each corresponding basis function from basisFunctions
    private IEnumerable<BasisFunction> createDenominatorBases(IRegressionProblemData problemData, IEnumerable<BasisFunction> basisFunctions) {
      var y = new BasisFunction(problemData.TargetVariable, problemData.TargetVariableValues.ToArray(), false);
      var denomBasisFuncs = new List<BasisFunction>();
      foreach (var func in basisFunctions) {
        var denomFunc = y * func;
        denomBasisFuncs.Add(denomFunc);
      }
      return denomBasisFuncs;
    }

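    // Rational models are handled with the usual FFX linearization: a model of the form
    //   y = (a0 + sum_i a_i * b_i(x)) / (1 + sum_j c_j * b_j(x))
    // can be rearranged so that the products y * b_j(x) appear as additional pseudo-input columns
    // of an ordinary linear regression problem. The bases created above are exactly these
    // y * b_j columns; Tree() later moves them back into the denominator of the expression.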

    private static string expToString(double exponent, string varname) {
      if (exponent.IsAlmost(1)) return varname;
      if (exponent.IsAlmost(1.0 / 2)) return OpCodeToString.GetByFirst(OpCode.SquareRoot) + "(" + varname + ")";
      if (exponent.IsAlmost(1.0 / 3)) return OpCodeToString.GetByFirst(OpCode.CubeRoot) + "(" + varname + ")";
      if (exponent.IsAlmost(2)) return OpCodeToString.GetByFirst(OpCode.Square) + "(" + varname + ")";
      if (exponent.IsAlmost(3)) return OpCodeToString.GetByFirst(OpCode.Cube) + "(" + varname + ")";
      return varname + " ^ " + exponent;
    }


    // evaluates a unary nonlinear function; note that OpCode.Log is evaluated as log base 10 here
    public static double eval(OpCode op, double x) {
      switch (op) {
        case OpCode.Absolute:
          return Math.Abs(x);
        case OpCode.Log:
          return Math.Log10(x);
        case OpCode.Sin:
          return Math.Sin(x);
        case OpCode.Cos:
          return Math.Cos(x);
        default:
          throw new ArgumentException("Unimplemented operator: " + op.ToString());
      }
    }


    // alternative to LearnModels: delegates the fit to the ElasticNetLinearRegression algorithm
    // (currently not called from Run)
    private void PathwiseLearning(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      ElasticNetLinearRegression reg = new ElasticNetLinearRegression();
      reg.Lambda = Lambda;
      reg.Penality = Penalty; // property name as defined in the Glmnet plugin
      reg.Problem.ProblemData = createProblemData(problemData, basisFunctions);
      reg.Start();
      Results.AddRange(reg.Results);
    }


    private void LearnModels(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      double[] lambda;
      double[] trainNMSE;
      double[] testNMSE;
      double[,] coeff;
      double[] intercept;
      int numNumeratorBases = ConsiderDenominations ? basisFunctions.Count / 2 : basisFunctions.Count;

      // wraps the list of basis functions in a dataset, so that it can be passed on to the ElNet function
      var X_b = createProblemData(problemData, basisFunctions);

      ElasticNetLinearRegression.RunElasticNetLinearRegression(X_b, Penalty, out lambda, out trainNMSE, out testNMSE, out coeff, out intercept);

      var errorTable = NMSEGraph(coeff, lambda, trainNMSE, testNMSE);
      Results.Add(new Result(errorTable.Name, errorTable.Description, errorTable));
      var coeffTable = CoefficientGraph(coeff, lambda, X_b.AllowedInputVariables, X_b.Dataset);
      Results.Add(new Result(coeffTable.Name, coeffTable.Description, coeffTable));

      ItemCollection<IResult> models = new ItemCollection<IResult>();
      for (int modelIdx = 0; modelIdx < coeff.GetUpperBound(0); modelIdx++) {
        var tree = Tree(basisFunctions, GetRow(coeff, modelIdx), intercept[modelIdx]);
        ISymbolicRegressionModel m = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
        ISymbolicRegressionSolution s = new SymbolicRegressionSolution(m, Problem.ProblemData);
        models.Add(new Result("Solution " + modelIdx, s));
      }

      Results.Add(new Result("Models", "The model path returned by the Elastic Net Regression (not only the pareto-optimal subset).", models));
    }

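    // Layout of the elastic-net output used above: lambda, trainNMSE, testNMSE and intercept hold one
    // entry per step of the regularization path, and coeff is a [path step, basis function] matrix,
    // which is why GetRow(coeff, modelIdx) yields the coefficient vector of a single model.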

    private static IndexedDataTable<double> CoefficientGraph(double[,] coeff, double[] lambda, IEnumerable<string> allowedVars, IDataset ds) {
      var coeffTable = new IndexedDataTable<double>("Coefficients", "The paths of standardized coefficient values over different lambda values");
      coeffTable.VisualProperties.YAxisMaximumAuto = false;
      coeffTable.VisualProperties.YAxisMinimumAuto = false;
      coeffTable.VisualProperties.XAxisMaximumAuto = false;
      coeffTable.VisualProperties.XAxisMinimumAuto = false;

      coeffTable.VisualProperties.XAxisLogScale = true;
      coeffTable.VisualProperties.XAxisTitle = "Lambda";
      coeffTable.VisualProperties.YAxisTitle = "Coefficients";
      coeffTable.VisualProperties.SecondYAxisTitle = "Number of variables";

      var nLambdas = lambda.Length;
      var nCoeff = coeff.GetLength(1);
      var dataRows = new IndexedDataRow<double>[nCoeff];
      var numNonZeroCoeffs = new int[nLambdas];

      var doubleVariables = allowedVars.Where(ds.VariableHasType<double>);
      var factorVariableNames = allowedVars.Where(ds.VariableHasType<string>);
      var factorVariablesAndValues = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)
      {
        int i = 0;
        foreach (var factorVariableAndValues in factorVariablesAndValues) {
          foreach (var factorValue in factorVariableAndValues.Value) {
            double sigma = ds.GetStringValues(factorVariableAndValues.Key)
              .Select(s => s == factorValue ? 1.0 : 0.0)
              .StandardDeviation(); // calc std dev of binary indicator
            var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
            dataRows[i] = new IndexedDataRow<double>(factorVariableAndValues.Key + "=" + factorValue, factorVariableAndValues.Key + "=" + factorValue, path);
            i++;
          }
        }

        foreach (var doubleVariable in doubleVariables) {
          double sigma = ds.GetDoubleValues(doubleVariable).StandardDeviation();
          var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
          dataRows[i] = new IndexedDataRow<double>(doubleVariable, doubleVariable, path);
          i++;
        }
        // add to coeffTable by total weight (larger area under the curve => more important)
        foreach (var r in dataRows.OrderByDescending(r => r.Values.Select(t => t.Item2).Sum(x => Math.Abs(x)))) {
          coeffTable.Rows.Add(r);
        }
      }

      for (int i = 0; i < coeff.GetLength(0); i++) {
        for (int j = 0; j < coeff.GetLength(1); j++) {
          if (!coeff[i, j].IsAlmost(0.0)) {
            numNonZeroCoeffs[i]++;
          }
        }
      }
      if (lambda.Length > 2) {
        coeffTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
        coeffTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
      }
      coeffTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
      coeffTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      coeffTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;

      return coeffTable;
    }


    private static IndexedDataTable<double> NMSEGraph(double[,] coeff, double[] lambda, double[] trainNMSE, double[] testNMSE) {
      var errorTable = new IndexedDataTable<double>("NMSE", "Path of NMSE values over different lambda values");
      var numNonZeroCoeffs = new int[lambda.Length];
      errorTable.VisualProperties.YAxisMaximumAuto = false;
      errorTable.VisualProperties.YAxisMinimumAuto = false;
      errorTable.VisualProperties.XAxisMaximumAuto = false;
      errorTable.VisualProperties.XAxisMinimumAuto = false;

      for (int i = 0; i < coeff.GetLength(0); i++) {
        for (int j = 0; j < coeff.GetLength(1); j++) {
          if (!coeff[i, j].IsAlmost(0.0)) {
            numNonZeroCoeffs[i]++;
          }
        }
      }

      errorTable.VisualProperties.YAxisMinimumFixedValue = 0;
      errorTable.VisualProperties.YAxisMaximumFixedValue = 1.0;
      errorTable.VisualProperties.XAxisLogScale = true;
      errorTable.VisualProperties.XAxisTitle = "Lambda";
      errorTable.VisualProperties.YAxisTitle = "Normalized mean of squared errors (NMSE)";
      errorTable.VisualProperties.SecondYAxisTitle = "Number of variables";
      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (train)", "Path of NMSE values over different lambda values", lambda.Zip(trainNMSE, (l, v) => Tuple.Create(l, v))));
      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (test)", "Path of NMSE values over different lambda values", lambda.Zip(testNMSE, (l, v) => Tuple.Create(l, v))));
      errorTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
      if (lambda.Length > 2) {
        errorTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
        errorTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
      }
      errorTable.Rows["NMSE (train)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      errorTable.Rows["NMSE (test)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      errorTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
      errorTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;

      return errorTable;
    }


    private ISymbolicExpressionTree Tree(List<BasisFunction> basisFunctions, double[] coeffs, double offset) {
      Debug.Assert(basisFunctions.Count() == coeffs.Length);
      var numNumeratorFuncs = ConsiderDenominations ? basisFunctions.Count() / 2 : basisFunctions.Count();
      var numeratorBasisFuncs = basisFunctions.Take(numNumeratorFuncs);

      // returns true if there exists at least 1 coefficient value in the model that is part of the denominator
      // (i.e. if there exists at least 1 non-zero value in the second half of the array)
      bool withDenom(double[] coeffarr) => coeffarr.Skip(coeffarr.Length / 2).Any(val => !val.IsAlmost(0.0));
      string model = "(" + offset.ToString();
      for (int i = 0; i < numNumeratorFuncs; i++) {
        var func = basisFunctions.ElementAt(i);
        // only generate nodes for relevant basis functions (those with non-zero coeffs)
        if (!coeffs[i].IsAlmost(0.0))
          model += " + (" + coeffs[i] + ") * " + func.Var;
      }
      if (ConsiderDenominations && withDenom(coeffs)) {
        model += ") / (1";
        for (int i = numNumeratorFuncs; i < basisFunctions.Count(); i++) {
          var func = basisFunctions.ElementAt(i);
          // only generate nodes for relevant basis functions (those with non-zero coeffs)
          if (!coeffs[i].IsAlmost(0.0))
            model += " + (" + coeffs[i] + ") * " + func.Var.Substring(4);
        }
      }
      model += ")";
      InfixExpressionParser p = new InfixExpressionParser();
      return p.Parse(model);
    }

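    // Example of the infix string handed to the parser (illustrative values; assumes a numerator
    // base named LOG(x1) and one denominator base whose name carries a 4-character prefix that
    // Substring(4) strips):
    //   "(0.32 + (1.7) * LOG(x1)) / (1 + (0.05) * SQR(x2))"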

    // wraps the list of basis functions into an IRegressionProblemData object
    private static IRegressionProblemData createProblemData(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
      List<string> variableNames = new List<string>();
      List<IList> variableVals = new List<IList>();
      foreach (var basisFunc in basisFunctions) {
        variableNames.Add(basisFunc.Var);
        // basisFunctions already contains the calculated values of the corresponding basis function, so you can just take that value
        variableVals.Add(new List<double>(basisFunc.Val));
      }
      var matrix = new ModifiableDataset(variableNames, variableVals);

      // add the unmodified target variable to the matrix
      matrix.AddVariable(problemData.TargetVariable, problemData.TargetVariableValues.ToList());
      var allowedInputVars = matrix.VariableNames.Where(x => !x.Equals(problemData.TargetVariable));
      IRegressionProblemData rpd = new RegressionProblemData(matrix, allowedInputVars, problemData.TargetVariable);
      rpd.TrainingPartition.Start = problemData.TrainingPartition.Start;
      rpd.TrainingPartition.End = problemData.TrainingPartition.End;
      rpd.TestPartition.Start = problemData.TestPartition.Start;
      rpd.TestPartition.End = problemData.TestPartition.End;
      return rpd;
    }


    private static bool ok(double[] data) => data.All(x => !double.IsNaN(x) && !double.IsInfinity(x));

    // helper function which returns a row of a 2D array
    private static T[] GetRow<T>(T[,] matrix, int row) {
      var columns = matrix.GetLength(1);
      var array = new T[columns];
      for (int i = 0; i < columns; ++i)
        array[i] = matrix[row, i];
      return array;
    }

    // returns all models with a pareto-optimal tradeoff between error and complexity
    // TODO: not implemented yet, currently unused
    private static List<IRegressionSolution> nondominatedFilter(double[][] coefficientVectorSet, BasisFunction[] basisFunctions) {
      return null;
    }
  }
}