Changeset 17227


Ignore:
Timestamp:
09/02/19 16:30:38 (3 weeks ago)
Author:
lleko
Message:

#3022: Add implementation for FFX.

Location:
branches/3022-FastFunctionExtraction
Files:
2 added
4 edited

Legend:

Unmodified
Added
Removed
  • branches/3022-FastFunctionExtraction

    • Property svn:global-ignores set to
      FFX_Python
  • branches/3022-FastFunctionExtraction/FFX/BasisFunction.cs

    r17219 r17227  
    1 using System;
     1using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     2using System;
    23using System.Collections.Generic;
     4using System.Diagnostics;
    35using System.Linq;
    46using System.Text;
     
    79namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction
    810{
    9     public enum operators { Abs, Log, Sin, Cos };
    1011
    11     class BasisFunction
     12    struct BasisFunction
    1213    {
    13         public double _val { get; set; }
    14         public string _var { get; set; }
    15         public double _exp { get; set; }
    16         public operators _op { get; set; }
     14        public string Var { get; set; }     // e.g. "Abs(Column1 ** 2)"
     15        public double[] Val { get; set; }   // this holds the already calculated values, i.e. the function written in Var
     16        public bool IsOperator { get; set; }// alg needs to check if basis function has an operator
     17        public NonlinOp Operator { get; }
     18
     19        public BasisFunction(string var, double[] val, bool isOperator, NonlinOp op = NonlinOp.None)
     20        {
     21            this.Var = var;
     22            this.Val= val;
     23            this.IsOperator = isOperator;
     24            this.Operator = op;
     25        }
     26
     27        public static BasisFunction operator *(BasisFunction a, BasisFunction b)
     28        {
     29            Debug.Assert(a.Val.Length == b.Val.Length);
     30            double[] newVal = new double[a.Val.Length];
     31            for(int i = 0; i < a.Val.Length; i++)
     32            {
     33                newVal[i] = a.Val[i] * b.Val[i];
     34            }
     35            return new BasisFunction(a.Var + " * " + b.Var, newVal, false);
     36        }
     37
     38        public int Complexity() => 1;
     39
     40        public ISymbolicExpressionTree Tree()
     41        {
     42            return null;
     43        }
     44       
    1745    }
    1846}
  • branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs

    r17219 r17227  
    1515using System.Collections.Generic;
    1616using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    17 
    18 namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction
    19 {
    20 
    21     [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
    22     [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
    23     [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive
    24     public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem>
    25     {
    26         private enum Operator { Abs, Log };
    27         private static readonly double[] exponents = { 0.5, 1, 2 };
    28 
    29         private const string PenaltyParameterName = "Penalty";
    30         private const string ConsiderInteractionsParameterName = "Consider Interactions";
    31         private const string ConsiderDenominationParameterName = "Consider Denomination";
    32         private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    33         private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
    34         private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    35 
    36         #region parameters
    37         public IValueParameter<BoolValue> ConsiderInteractionsParameter
    38         {
    39             get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
    40         }
    41         #endregion
    42 
    43         #region properties
    44         public bool ConsiderInteractions
    45         {
    46             get { return ConsiderInteractionsParameter.Value.Value; }
    47             set { ConsiderInteractionsParameter.Value.Value = value; }
    48         }
    49         #endregion
    50 
    51         [StorableConstructor]
    52         private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
    53         public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner)
    54         {
    55             // Don't forget to call the cloning ctor of the base class 
    56             // This class does not have fields, therefore we don't need to actually clone anything
    57         }
    58         public FastFunctionExtraction() : base()
    59         {
    60             // algorithm parameters are shown in the GUI
    61             Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5)));
    62             Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want to consider interactions, otherwise false.", new BoolValue(true)));
    63             Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want to consider denominations, otherwise false.", new BoolValue(true)));
    64             Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want to consider exponentiation, otherwise false.", new BoolValue(true)));
    65             Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want to consider nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
    66             Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want to consider Hinge Functions, otherwise false.", new BoolValue(true)));
    67         }
    68 
    69         [StorableHook(HookType.AfterDeserialization)]
    70         private void AfterDeserialization() { }
    71 
    72         public override IDeepCloneable Clone(Cloner cloner)
    73         {
    74             return new FastFunctionExtraction(this, cloner);
    75         }
    76 
    77         public override Type ProblemType { get { return typeof(RegressionProblem); } }
    78         public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }
    79 
    80 
    81         protected override void Run(CancellationToken cancellationToken)
    82         {
    83             var basisFunctions = generateBasisFunctions(Problem.ProblemData);
    84             var x = Problem.ProblemData.AllowedInputsTrainingValues;
    85             List<SymbolicExpressionTree> trees = new List<SymbolicExpressionTree>();
    86 
    87 
    88             foreach (var basisFunc in basisFunctions)
    89             {
    90                 // add tree representation of basisFunc to trees
    91                 trees.Add(generateSymbolicExpressionTree(basisFunc));
    92             }
    93 
    94             foreach (var tree in trees)
    95             {
    96                 // create new data through the help of the Interpreter
    97                 //IEnumerable<double> responses =
    98             }
    99 
    100             var coefficientVectorSet = findCoefficientValues(basisFunctions);
    101             var paretoFront = nondominatedFilter(coefficientVectorSet);
    102         }
    103 
    104         private SymbolicExpressionTree generateSymbolicExpressionTree(KeyValuePair<string, double[]> basisFunc)
    105         {
    106             throw new NotImplementedException();
    107         }
    108 
    109         // generate all possible models
    110         private static Dictionary<string, double[]> generateBasisFunctions(IRegressionProblemData problemData)
    111         {
    112             var basisFunctions = generateUnivariateBases(problemData);
    113             return basisFunctions;
    114         }
    115 
    116         private static Dictionary<string, double[]> generateUnivariateBases(IRegressionProblemData problemData)
    117         {
    118 
    119             var dataset = problemData.Dataset;
    120             var rows = problemData.TrainingIndices;
    121             var B1 = new Dictionary<string, double[]>();
    122 
    123             foreach (var variableName in dataset.VariableNames)
    124             {
    125                 foreach (var exp in new[] { 0.5, 1, 2 })
    126                 {
    127                     var name = variableName + " ** " + exp;
    128                     var data = dataset.GetDoubleValues(variableName, rows).Select(x => Math.Pow(x, exp)).ToArray();
    129                     B1.Add(name, data);
    130                     foreach (Operator op in Enum.GetValues(typeof(Operator)))
    131                     {
    132                         var inner_name = op.ToString() + "(" + name + ")";
    133                         var inner_data = data.Select(x => executeOperator(x, op)).ToArray();
    134                         B1.Add(inner_name, inner_data);
    135                     }
    136                 }
    137             }
    138            
    139             return B1;
    140         }
    141 
    142         private static double executeOperator(double x, Operator op)
    143         {
    144             switch (op)
    145             {
    146                 case Operator.Abs:
    147                     return x > 0 ? x : -x;
    148                 case Operator.Log:
    149                     return Math.Log10(x);
    150                 default:
    151                     throw new NotImplementedException();
    152             }
    153         }
    154 
    155         private static Dictionary<string, double[]> generateMultiVariateBases(Dictionary<string, double[]> B1)
    156         {
    157             var B2 = new Dictionary<string, double[]>();
    158             for(int i = 1; i <= B1.Count(); i++ )
    159             {
    160                 var b_i = B1.ElementAt(i);
    161                 for (int j = 1; j < i; i++)
    162                 {
    163                     var b_j = B1.ElementAt(j);
    164                 }
    165             }
    166 
    167             // return union of B1 and B2
    168             return B2.Concat(B1).ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
    169         }
    170 
    171         private static object findCoefficientValues(IEnumerable<KeyValuePair<string, double[]>> basisFunctions)
    172         {
    173             return new object();
    174         }
    175 
    176         private static object nondominatedFilter(object coefficientVectorSet)
    177         {
    178             return new object();
    179         }
    180 
    181         public override bool SupportsPause
    182         {
    183             get { return false; }
    184         }
    185     }
     17using System.Collections;
     18using System.Diagnostics;
     19using HeuristicLab.Problems.DataAnalysis.Symbolic;
     20using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
     21using HeuristicLab.Analysis;
     22using HeuristicLab.Collections;
     23
     24namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
     25
     26  [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
     27  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
     28  [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive
     29  public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem> {
     30
     31    private static readonly double[] exponents = { 0.5, 1, 2 };
     32    private static readonly OpCode[] nonlinFuncs = { OpCode.Absolute, OpCode.Log, OpCode.Sin, OpCode.Cos };
     33
     34    private static readonly BidirectionalDictionary<OpCode, string> OpCodeToString = new BidirectionalDictionary<OpCode, string> {
     35        { OpCode.Log, "LOG" },
     36        { OpCode.Absolute, "ABS"},
     37        { OpCode.Sin, "SIN"},
     38        { OpCode.Cos, "COS"},
     39        { OpCode.Square, "SQR"},
     40        { OpCode.SquareRoot, "SQRT"},
     41        { OpCode.Cube, "CUBE"},
     42        { OpCode.CubeRoot, "CUBEROOT"}
     43    };
     44
     45    private const string ConsiderInteractionsParameterName = "Consider Interactions";
     46    private const string ConsiderDenominationParameterName = "Consider Denomination";
     47    private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
     48    private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
     49    private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
     50    private const string PenaltyParameterName = "Penalty";
     51    private const string LambdaParameterName = "Lambda";
     52    private const string NonlinearFuncsParameterName = "Nonlinear Functions";
     53
     54    #region parameters
     55    public IValueParameter<BoolValue> ConsiderInteractionsParameter
     56    {
     57      get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
     58    }
     59    public IValueParameter<BoolValue> ConsiderDenominationsParameter
     60    {
     61      get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
     62    }
     63    public IValueParameter<BoolValue> ConsiderExponentiationsParameter
     64    {
     65      get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
     66    }
     67    public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter
     68    {
     69      get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
     70    }
     71    public IValueParameter<BoolValue> ConsiderHingeFuncsParameter
     72    {
     73      get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
     74    }
     75    public IValueParameter<DoubleValue> PenaltyParameter
     76    {
     77      get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
     78    }
     79    public IValueParameter<DoubleValue> LambdaParameter
     80    {
     81      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
     82    }
     83    public IValueParameter<CheckedItemCollection<EnumValue<OpCode>>> NonlinearFuncsParameter
     84    {
     85      get { return (IValueParameter<CheckedItemCollection<EnumValue<OpCode>>>)Parameters[NonlinearFuncsParameterName]; }
     86    }
     87    #endregion
     88
     89    #region properties
     90    public bool ConsiderInteractions
     91    {
     92      get { return ConsiderInteractionsParameter.Value.Value; }
     93      set { ConsiderInteractionsParameter.Value.Value = value; }
     94    }
     95    public bool ConsiderDenominations
     96    {
     97      get { return ConsiderDenominationsParameter.Value.Value; }
     98      set { ConsiderDenominationsParameter.Value.Value = value; }
     99    }
     100    public bool ConsiderExponentiations
     101    {
     102      get { return ConsiderExponentiationsParameter.Value.Value; }
     103      set { ConsiderExponentiationsParameter.Value.Value = value; }
     104    }
     105    public bool ConsiderNonlinearFuncs
     106    {
     107      get { return ConsiderNonlinearFuncsParameter.Value.Value; }
     108      set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
     109    }
     110    public bool ConsiderHingeFuncs
     111    {
     112      get { return ConsiderHingeFuncsParameter.Value.Value; }
     113      set { ConsiderHingeFuncsParameter.Value.Value = value; }
     114    }
     115    public double Penalty
     116    {
     117      get { return PenaltyParameter.Value.Value; }
     118      set { PenaltyParameter.Value.Value = value; }
     119    }
     120    public DoubleValue Lambda
     121    {
     122      get { return LambdaParameter.Value; }
     123      set { LambdaParameter.Value = value; }
     124    }
     125    public CheckedItemCollection<EnumValue<OpCode>> NonlinearFuncs
     126    {
     127      get { return NonlinearFuncsParameter.Value; }
     128      set { NonlinearFuncsParameter.Value = value; }
     129    }
     130    #endregion
     131
     132
     133    [StorableConstructor]
     134    private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
     135    public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
     136    }
     137    public FastFunctionExtraction() : base() {
     138      var items = new CheckedItemCollection<EnumValue<OpCode>>();
     139      foreach (var op in nonlinFuncs) {
     140        items.Add(new EnumValue<OpCode>(op));
     141      }
     142      base.Problem = new RegressionProblem();
     143      Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
     144      Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
     145      Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
     146      Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
     147      Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
     148      Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.9)));
     149      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
     150      Parameters.Add(new ValueParameter<CheckedItemCollection<EnumValue<OpCode>>>(NonlinearFuncsParameterName, "What nonlinear functions the models should be able to include.", items));
     151    }
     152
     153    [StorableHook(HookType.AfterDeserialization)]
     154    private void AfterDeserialization() { }
     155
     156    public override IDeepCloneable Clone(Cloner cloner) {
     157      return new FastFunctionExtraction(this, cloner);
     158    }
     159
     160    public override Type ProblemType { get { return typeof(RegressionProblem); } }
     161    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }
     162
     163    public override bool SupportsPause { get { return true; } }
     164
     165    protected override void Run(CancellationToken cancellationToken) {
     166      var basisFunctions = createBasisFunctions(Problem.ProblemData);
     167      Results.Add(new Result("Basis Functions", "A Dataset consisting of the generated Basis Functions from FFX Alg Step 1.", createProblemData(Problem.ProblemData, basisFunctions)));
     168
     169      // add denominator bases to the already existing basis functions
     170      if (ConsiderDenominations) basisFunctions = basisFunctions.Concat(createDenominatorBases(Problem.ProblemData, basisFunctions)).ToList();
     171
     172      // create either path of solutions, or one solution for given lambda
     173      LearnModels(Problem.ProblemData, basisFunctions);
     174    }
     175
     176    private List<BasisFunction> createBasisFunctions(IRegressionProblemData problemData) {
     177      var basisFunctions = createUnivariateBases(problemData);
     178      basisFunctions = basisFunctions.Concat(createMultivariateBases(basisFunctions)).ToList();
     179      return basisFunctions;
     180    }
     181
     182    private List<BasisFunction> createUnivariateBases(IRegressionProblemData problemData) {
     183      var B1 = new List<BasisFunction>();
     184      var inputVariables = problemData.AllowedInputVariables;
     185      var validExponents = ConsiderExponentiations ? exponents : new double[] { 1 };
     186      var validFuncs = NonlinearFuncs.CheckedItems.Select(val => val.Value);
     187      // TODO: add Hinge functions
     188
     189      foreach (var variableName in inputVariables) {
     190        foreach (var exp in validExponents) {
     191          var data = problemData.Dataset.GetDoubleValues(variableName).Select(x => Math.Pow(x, exp)).ToArray();
     192          if (!ok(data)) continue;
     193          var name = expToString(exp, variableName);
     194          B1.Add(new BasisFunction(name, data, false));
     195          foreach (OpCode _op in validFuncs) {
     196            var inner_data = data.Select(x => eval(_op, x)).ToArray();
     197            if (!ok(inner_data)) continue;
     198            var inner_name = OpCodeToString.GetByFirst(_op) + "(" + name + ")";
     199            B1.Add(new BasisFunction(inner_name, inner_data, true));
     200          }
     201        }
     202      }
     203      return B1;
     204    }
     205
     206    private List<BasisFunction> createMultivariateBases(List<BasisFunction> B1) {
     207      if (!ConsiderInteractions) return B1;
     208      var B2 = new List<BasisFunction>();
     209      for (int i = 0; i < B1.Count(); i++) {
     210        var b_i = B1.ElementAt(i);
     211        for (int j = 0; j < i; j++) {
     212          var b_j = B1.ElementAt(j);
     213          if (b_j.IsOperator) continue; // disallow op() * op()
     214          var b_inter = b_i * b_j;
     215          B2.Add(b_inter);
     216        }
     217      }
     218
     219      return B2;
     220      // return union of B1 and B2
     221    }
     222
     223    // creates 1 denominator basis function for each corresponding basis function from basisFunctions
     224    private IEnumerable<BasisFunction> createDenominatorBases(IRegressionProblemData problemData, IEnumerable<BasisFunction> basisFunctions) {
     225      var y = new BasisFunction(problemData.TargetVariable, problemData.TargetVariableValues.ToArray(), false);
     226      var denomBasisFuncs = new List<BasisFunction>();
     227      foreach (var func in basisFunctions) {
     228        var denomFunc = y * func;
     229        denomBasisFuncs.Add(denomFunc);
     230      }
     231      return denomBasisFuncs;
     232    }
     233
     234    private static string expToString(double exponent, string varname) {
     235      if (exponent.IsAlmost(1)) return varname;
     236      if (exponent.IsAlmost(1 / 2)) return OpCodeToString.GetByFirst(OpCode.SquareRoot) + "(" + varname + ")";
     237      if (exponent.IsAlmost(1 / 3)) return OpCodeToString.GetByFirst(OpCode.CubeRoot) + "(" + varname + ")";
     238      if (exponent.IsAlmost(2)) return OpCodeToString.GetByFirst(OpCode.Square) + "(" + varname + ")";
     239      if (exponent.IsAlmost(3)) return OpCodeToString.GetByFirst(OpCode.Cube) + "(" + varname + ")";
     240      else return varname + " ^ " + exponent;
     241    }
     242
     243    public static double eval(OpCode op, double x) {
     244      switch (op) {
     245        case OpCode.Absolute:
     246          return Math.Abs(x);
     247        case OpCode.Log:
     248          return Math.Log10(x);
     249        case OpCode.Sin:
     250          return Math.Sin(x);
     251        case OpCode.Cos:
     252          return Math.Cos(x);
     253        default:
     254          throw new Exception("Unimplemented operator: " + op.ToString());
     255      }
     256    }
     257
     258    private void PathwiseLearning(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
     259      ElasticNetLinearRegression reg = new ElasticNetLinearRegression();
     260      reg.Lambda = Lambda;
     261      reg.Penality = Penalty;
     262      reg.Problem.ProblemData = createProblemData(problemData, basisFunctions);
     263      reg.Start();
     264      Results.AddRange(reg.Results);
     265    }
     266
     267    private void LearnModels(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
     268      double[] lambda;
     269      double[] trainNMSE;
     270      double[] testNMSE;
     271      double[,] coeff;
     272      double[] intercept;
     273      int numNominatorBases = ConsiderDenominations ? basisFunctions.Count / 2 : basisFunctions.Count;
     274
     275      // wraps the list of basis functions in a dataset, so that it can be passed on to the ElNet function
     276      var X_b = createProblemData(problemData, basisFunctions);
     277
     278      ElasticNetLinearRegression.RunElasticNetLinearRegression(X_b, Penalty, out lambda, out trainNMSE, out testNMSE, out coeff, out intercept);
     279
     280      var errorTable = NMSEGraph(coeff, lambda, trainNMSE, testNMSE);
     281      Results.Add(new Result(errorTable.Name, errorTable.Description, errorTable));
     282      var coeffTable = CoefficientGraph(coeff, lambda, X_b.AllowedInputVariables, X_b.Dataset);
     283      Results.Add(new Result(coeffTable.Name, coeffTable.Description, coeffTable));
     284
     285      ItemCollection<IResult> models = new ItemCollection<IResult>();
     286      for (int modelIdx = 0; modelIdx < coeff.GetUpperBound(0); modelIdx++) {
     287        var tree = Tree(basisFunctions, GetRow(coeff, modelIdx), intercept[modelIdx]);
     288        ISymbolicRegressionModel m = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
     289        ISymbolicRegressionSolution s = new SymbolicRegressionSolution(m, Problem.ProblemData);
     290        models.Add(new Result("Solution " + modelIdx, s));
     291      }
     292
     293      Results.Add(new Result("Models", "The model path returned by the Elastic Net Regression (not only the pareto-optimal subset). ", models));
     294    }
     295
     296    private static IndexedDataTable<double> CoefficientGraph(double[,] coeff, double[] lambda, IEnumerable<string> allowedVars, IDataset ds) {
     297      var coeffTable = new IndexedDataTable<double>("Coefficients", "The paths of standarized coefficient values over different lambda values");
     298      coeffTable.VisualProperties.YAxisMaximumAuto = false;
     299      coeffTable.VisualProperties.YAxisMinimumAuto = false;
     300      coeffTable.VisualProperties.XAxisMaximumAuto = false;
     301      coeffTable.VisualProperties.XAxisMinimumAuto = false;
     302
     303      coeffTable.VisualProperties.XAxisLogScale = true;
     304      coeffTable.VisualProperties.XAxisTitle = "Lambda";
     305      coeffTable.VisualProperties.YAxisTitle = "Coefficients";
     306      coeffTable.VisualProperties.SecondYAxisTitle = "Number of variables";
     307
     308      var nLambdas = lambda.Length;
     309      var nCoeff = coeff.GetLength(1);
     310      var dataRows = new IndexedDataRow<double>[nCoeff];
     311      var numNonZeroCoeffs = new int[nLambdas];
     312
     313      var doubleVariables = allowedVars.Where(ds.VariableHasType<double>);
     314      var factorVariableNames = allowedVars.Where(ds.VariableHasType<string>);
     315      var factorVariablesAndValues = ds.GetFactorVariableValues(factorVariableNames, Enumerable.Range(0, ds.Rows)); // must consider all factor values (in train and test set)
     316      {
     317        int i = 0;
     318        foreach (var factorVariableAndValues in factorVariablesAndValues) {
     319          foreach (var factorValue in factorVariableAndValues.Value) {
     320            double sigma = ds.GetStringValues(factorVariableAndValues.Key)
     321              .Select(s => s == factorValue ? 1.0 : 0.0)
     322              .StandardDeviation(); // calc std dev of binary indicator
     323            var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
     324            dataRows[i] = new IndexedDataRow<double>(factorVariableAndValues.Key + "=" + factorValue, factorVariableAndValues.Key + "=" + factorValue, path);
     325            i++;
     326          }
     327        }
     328
     329        foreach (var doubleVariable in doubleVariables) {
     330          double sigma = ds.GetDoubleValues(doubleVariable).StandardDeviation();
     331          var path = Enumerable.Range(0, nLambdas).Select(r => Tuple.Create(lambda[r], coeff[r, i] * sigma)).ToArray();
     332          dataRows[i] = new IndexedDataRow<double>(doubleVariable, doubleVariable, path);
     333          i++;
     334        }
     335        // add to coeffTable by total weight (larger area under the curve => more important);
     336        foreach (var r in dataRows.OrderByDescending(r => r.Values.Select(t => t.Item2).Sum(x => Math.Abs(x)))) {
     337          coeffTable.Rows.Add(r);
     338        }
     339      }
     340
     341      for (int i = 0; i < coeff.GetLength(0); i++) {
     342        for (int j = 0; j < coeff.GetLength(1); j++) {
     343          if (!coeff[i, j].IsAlmost(0.0)) {
     344            numNonZeroCoeffs[i]++;
     345          }
     346        }
     347      }
     348      if (lambda.Length > 2) {
     349        coeffTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
     350        coeffTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
     351      }
     352      coeffTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
     353      coeffTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     354      coeffTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;
     355
     356      return coeffTable;
     357    }
     358
     359    private static IndexedDataTable<double> NMSEGraph(double[,] coeff, double[] lambda, double[] trainNMSE, double[] testNMSE) {
     360      var errorTable = new IndexedDataTable<double>("NMSE", "Path of NMSE values over different lambda values");
     361      var numNonZeroCoeffs = new int[lambda.Length];
     362      errorTable.VisualProperties.YAxisMaximumAuto = false;
     363      errorTable.VisualProperties.YAxisMinimumAuto = false;
     364      errorTable.VisualProperties.XAxisMaximumAuto = false;
     365      errorTable.VisualProperties.XAxisMinimumAuto = false;
     366
     367      for (int i = 0; i < coeff.GetLength(0); i++) {
     368        for (int j = 0; j < coeff.GetLength(1); j++) {
     369          if (!coeff[i, j].IsAlmost(0.0)) {
     370            numNonZeroCoeffs[i]++;
     371          }
     372        }
     373      }
     374
     375      errorTable.VisualProperties.YAxisMinimumFixedValue = 0;
     376      errorTable.VisualProperties.YAxisMaximumFixedValue = 1.0;
     377      errorTable.VisualProperties.XAxisLogScale = true;
     378      errorTable.VisualProperties.XAxisTitle = "Lambda";
     379      errorTable.VisualProperties.YAxisTitle = "Normalized mean of squared errors (NMSE)";
     380      errorTable.VisualProperties.SecondYAxisTitle = "Number of variables";
     381      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (train)", "Path of NMSE values over different lambda values", lambda.Zip(trainNMSE, (l, v) => Tuple.Create(l, v))));
     382      errorTable.Rows.Add(new IndexedDataRow<double>("NMSE (test)", "Path of NMSE values over different lambda values", lambda.Zip(testNMSE, (l, v) => Tuple.Create(l, v))));
     383      errorTable.Rows.Add(new IndexedDataRow<double>("Number of variables", "The number of non-zero coefficients for each step in the path", lambda.Zip(numNonZeroCoeffs, (l, v) => Tuple.Create(l, (double)v))));
     384      if (lambda.Length > 2) {
     385        errorTable.VisualProperties.XAxisMinimumFixedValue = Math.Pow(10, Math.Floor(Math.Log10(lambda.Last())));
     386        errorTable.VisualProperties.XAxisMaximumFixedValue = Math.Pow(10, Math.Ceiling(Math.Log10(lambda.Skip(1).First())));
     387      }
     388      errorTable.Rows["NMSE (train)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     389      errorTable.Rows["NMSE (test)"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     390      errorTable.Rows["Number of variables"].VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     391      errorTable.Rows["Number of variables"].VisualProperties.SecondYAxis = true;
     392
     393      return errorTable;
     394    }
     395
     396    private ISymbolicExpressionTree Tree(List<BasisFunction> basisFunctions, double[] coeffs, double offset) {
     397      Debug.Assert(basisFunctions.Count() == coeffs.Length);
     398      //SymbolicExpressionTree
     399      var numNumeratorFuncs = ConsiderDenominations ? basisFunctions.Count() / 2 : basisFunctions.Count();
     400      var numeratorBasisFuncs = basisFunctions.Take(numNumeratorFuncs);
     401
     402      // returns true if there exists at least 1 coefficient value in the model that is part of the denominator
     403      // (i.e. if there exists at least 1 non-zero value in the second half of the array)
     404      bool withDenom(double[] coeffarr) => coeffarr.Take(coeffarr.Length / 2).ToArray().Any(val => !val.IsAlmost(0.0));
     405      string model = "(" + offset.ToString();
     406      for (int i = 0; i < numNumeratorFuncs; i++) {
     407        var func = basisFunctions.ElementAt(i);
     408        // only generate nodes for relevant basis functions (those with non-zero coeffs)
     409        if (!coeffs[i].IsAlmost(0.0))
     410          model += " + (" + coeffs[i] + ") * " + func.Var;
     411      }
     412      if (ConsiderDenominations && withDenom(coeffs)) {
     413        model += ") / (1";
     414        for (int i = numNumeratorFuncs; i < basisFunctions.Count(); i++) {
     415          var func = basisFunctions.ElementAt(i);
     416          // only generate nodes for relevant basis functions (those with non-zero coeffs)
     417          if (!coeffs[i].IsAlmost(0.0))
     418            model += " + (" + coeffs[i] + ") * " + func.Var.Substring(4);
     419        }
     420      }
     421      model += ")";
     422      InfixExpressionParser p = new InfixExpressionParser();
     423      return p.Parse(model);
     424    }
     425
     426    // wraps the list of basis functions into an IRegressionProblemData object
     427    private static IRegressionProblemData createProblemData(IRegressionProblemData problemData, List<BasisFunction> basisFunctions) {
     428      List<string> variableNames = new List<string>();
     429      List<IList> variableVals = new List<IList>();
     430      foreach (var basisFunc in basisFunctions) {
     431        variableNames.Add(basisFunc.Var);
     432        // basisFunctions already contains the calculated values of the corresponding basis function, so you can just take that value
     433        variableVals.Add(new List<double>(basisFunc.Val));
     434      }
     435      var matrix = new ModifiableDataset(variableNames, variableVals);
     436
     437      // add the unmodified target variable to the matrix
     438      matrix.AddVariable(problemData.TargetVariable, problemData.TargetVariableValues.ToList());
     439      var allowedInputVars = matrix.VariableNames.Where(x => !x.Equals(problemData.TargetVariable));
     440      IRegressionProblemData rpd = new RegressionProblemData(matrix, allowedInputVars, problemData.TargetVariable);
     441      rpd.TrainingPartition.Start = problemData.TrainingPartition.Start;
     442      rpd.TrainingPartition.End = problemData.TrainingPartition.End;
     443      rpd.TestPartition.Start = problemData.TestPartition.Start;
     444      rpd.TestPartition.End = problemData.TestPartition.End;
     445      return rpd;
     446    }
     447
     448    private static bool ok(double[] data) => data.All(x => !double.IsNaN(x) && !double.IsInfinity(x));
     449
     450    // helper function which returns a row of a 2D array
     451    private static T[] GetRow<T>(T[,] matrix, int row) {
     452      var columns = matrix.GetLength(1);
     453      var array = new T[columns];
     454      for (int i = 0; i < columns; ++i)
     455        array[i] = matrix[row, i];
     456      return array;
     457    }
     458
     459    // returns all models with pareto-optimal tradeoff between error and complexity
     460    private static List<IRegressionSolution> nondominatedFilter(double[][] coefficientVectorSet, BasisFunction[] basisFunctions) {
     461      return null;
     462    }
     463  }
    186464}
  • branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.csproj

    r17219 r17227  
    4343      <SpecificVersion>False</SpecificVersion>
    4444      <HintPath>..\..\..\trunk\bin\HeuristicLab.Algorithms.DataAnalysis.Glmnet-3.4.dll</HintPath>
     45    </Reference>
     46    <Reference Include="HeuristicLab.Analysis-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     47      <SpecificVersion>False</SpecificVersion>
     48      <HintPath>..\..\..\trunk\bin\HeuristicLab.Analysis-3.3.dll</HintPath>
    4549    </Reference>
    4650    <Reference Include="HeuristicLab.Collections-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     
    9195      <HintPath>..\..\..\trunk\bin\HeuristicLab.Problems.DataAnalysis-3.4.dll</HintPath>
    9296    </Reference>
     97    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     98      <SpecificVersion>False</SpecificVersion>
     99      <HintPath>..\..\..\trunk\bin\HeuristicLab.Problems.DataAnalysis.Symbolic-3.4.dll</HintPath>
     100    </Reference>
    93101    <Reference Include="HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4, Version=3.4.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
    94102      <SpecificVersion>False</SpecificVersion>
     
    112120    <Compile Include="BasisFunction.cs" />
    113121    <Compile Include="FastFunctionExtraction.cs" />
    114     <Compile Include="GeneralizedLinearModel.cs" />
     122    <Compile Include="Operator.cs" />
    115123    <Compile Include="Plugin.cs" />
    116124    <Compile Include="Properties\AssemblyInfo.cs" />
Note: See TracChangeset for help on using the changeset viewer.