Changeset 14564


Ignore:
Timestamp:
01/14/17 19:08:39 (2 years ago)
Author:
gkronber
Message:

#2657,#2677 merged r14258, r14316, r14319 and 14347.

Location:
stable
Files:
8 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearRegression.cs

    r14116 r14564  
    2121
    2222using System;
    23 using System.Collections.Generic;
    2423using System.Linq;
     24using HeuristicLab.Analysis;
    2525using HeuristicLab.Common;
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Data;
     28using HeuristicLab.Optimization;
    2829using HeuristicLab.Parameters;
    29 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    30 using HeuristicLab.Optimization;
    3130using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3231using HeuristicLab.Problems.DataAnalysis;
    3332using HeuristicLab.Problems.DataAnalysis.Symbolic;
    3433using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
     34using HeuristicLab.Random;
    3535
    3636namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4545    private const string ModelStructureParameterName = "Model structure";
    4646    private const string IterationsParameterName = "Iterations";
     47    private const string RestartsParameterName = "Restarts";
     48    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     49    private const string SeedParameterName = "Seed";
     50    private const string InitParamsRandomlyParameterName = "InitializeParametersRandomly";
    4751
    4852    public IFixedValueParameter<StringValue> ModelStructureParameter {
     
    5155    public IFixedValueParameter<IntValue> IterationsParameter {
    5256      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
     57    }
     58
     59    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
     60      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     61    }
     62
     63    public IFixedValueParameter<IntValue> SeedParameter {
     64      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
     65    }
     66
     67    public IFixedValueParameter<IntValue> RestartsParameter {
     68      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
     69    }
     70
     71    public IFixedValueParameter<BoolValue> InitParametersRandomlyParameter {
     72      get { return (IFixedValueParameter<BoolValue>)Parameters[InitParamsRandomlyParameterName]; }
    5373    }
    5474
     
    6383    }
    6484
     85    public int Restarts {
     86      get { return RestartsParameter.Value.Value; }
     87      set { RestartsParameter.Value.Value = value; }
     88    }
     89
     90    public int Seed {
     91      get { return SeedParameter.Value.Value; }
     92      set { SeedParameter.Value.Value = value; }
     93    }
     94
     95    public bool SetSeedRandomly {
     96      get { return SetSeedRandomlyParameter.Value.Value; }
     97      set { SetSeedRandomlyParameter.Value.Value = value; }
     98    }
     99
     100    public bool InitializeParametersRandomly {
     101      get { return InitParametersRandomlyParameter.Value.Value; }
     102      set { InitParametersRandomlyParameter.Value.Value = value; }
     103    }
    65104
    66105    [StorableConstructor]
     
    74113      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
    75114      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The maximum number of iterations for constants optimization.", new IntValue(200)));
    76     }
     115      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts (>0)", new IntValue(10)));
     116      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
     117      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
     118      Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the real-valued model parameters should be initialized randomly in each restart.", new BoolValue(false)));
     119
     120      SetParameterHiddenState();
     121
     122      InitParametersRandomlyParameter.Value.ValueChanged += (sender, args) => {
     123        SetParameterHiddenState();
     124      };
     125    }
     126
     127    private void SetParameterHiddenState() {
     128      var hide = !InitializeParametersRandomly;
     129      RestartsParameter.Hidden = hide;
     130      SeedParameter.Hidden = hide;
     131      SetSeedRandomlyParameter.Hidden = hide;
     132    }
     133
    77134    [StorableHook(HookType.AfterDeserialization)]
    78     private void AfterDeserialization() { }
     135    private void AfterDeserialization() {
     136      // BackwardsCompatibility3.3
     137      #region Backwards compatible code, remove with 3.4
     138      if (!Parameters.ContainsKey(RestartsParameterName))
     139        Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts", new IntValue(1)));
     140      if (!Parameters.ContainsKey(SeedParameterName))
     141        Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
     142      if (!Parameters.ContainsKey(SetSeedRandomlyParameterName))
     143        Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
     144      if (!Parameters.ContainsKey(InitParamsRandomlyParameterName))
     145        Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the numeric parameters of the model should be initialized randomly.", new BoolValue(false)));
     146
     147      SetParameterHiddenState();
     148      InitParametersRandomlyParameter.Value.ValueChanged += (sender, args) => {
     149        SetParameterHiddenState();
     150      };
     151      #endregion
     152    }
    79153
    80154    public override IDeepCloneable Clone(Cloner cloner) {
     
    84158    #region nonlinear regression
    85159    protected override void Run() {
    86       var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
    87       Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", solution));
    88       Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(solution.TrainingRootMeanSquaredError)));
    89       Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(solution.TestRootMeanSquaredError)));
    90     }
    91 
    92     public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations) {
     160      IRegressionSolution bestSolution = null;
     161      if (InitializeParametersRandomly) {
     162        var qualityTable = new DataTable("RMSE table");
     163        qualityTable.VisualProperties.YAxisLogScale = true;
     164        var trainRMSERow = new DataRow("RMSE (train)");
     165        trainRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     166        var testRMSERow = new DataRow("RMSE test");
     167        testRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
     168
     169        qualityTable.Rows.Add(trainRMSERow);
     170        qualityTable.Rows.Add(testRMSERow);
     171        Results.Add(new Result(qualityTable.Name, qualityTable.Name + " for all restarts", qualityTable));
     172        if (SetSeedRandomly) Seed = (new System.Random()).Next();
     173        var rand = new MersenneTwister((uint)Seed);
     174        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
     175        trainRMSERow.Values.Add(bestSolution.TrainingRootMeanSquaredError);
     176        testRMSERow.Values.Add(bestSolution.TestRootMeanSquaredError);
     177        for (int r = 0; r < Restarts; r++) {
     178          var solution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations, rand);
     179          trainRMSERow.Values.Add(solution.TrainingRootMeanSquaredError);
     180          testRMSERow.Values.Add(solution.TestRootMeanSquaredError);
     181          if (solution.TrainingRootMeanSquaredError < bestSolution.TrainingRootMeanSquaredError) {
     182            bestSolution = solution;
     183          }
     184        }
     185      } else {
     186        bestSolution = CreateRegressionSolution(Problem.ProblemData, ModelStructure, Iterations);
     187      }
     188
     189      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", bestSolution));
     190      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(bestSolution.TrainingRootMeanSquaredError)));
     191      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(bestSolution.TestRootMeanSquaredError)));
     192
     193    }
     194
     195    /// <summary>
     196    /// Fits a model to the data by optimizing the numeric constants.
     197    /// Model is specified as infix expression containing variable names and numbers.
     198    /// The starting point for the numeric constants is initialized randomly if a random number generator is specified (~N(0,1)). Otherwise the user specified constants are
     199    /// used as a starting point.
     200    /// </summary>-
     201    /// <param name="problemData">Training and test data</param>
     202    /// <param name="modelStructure">The function as infix expression</param>
     203    /// <param name="maxIterations">Number of constant optimization iterations (using Levenberg-Marquardt algorithm)</param>
     204    /// <param name="random">Optional random number generator for random initialization of numeric constants.</param>
     205    /// <returns></returns>
     206    public static ISymbolicRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData, string modelStructure, int maxIterations, IRandom rand = null) {
    93207      var parser = new InfixExpressionParser();
    94208      var tree = parser.Parse(modelStructure);
    95       var simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
    96      
     209
    97210      if (!SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree)) throw new ArgumentException("The optimizer does not support the specified model structure.");
    98211
     212      // initialize constants randomly
     213      if (rand != null) {
     214        foreach (var node in tree.IterateNodesPrefix().OfType<ConstantTreeNode>()) {
     215          double f = Math.Exp(NormalDistributedRandom.NextDouble(rand, 0, 1));
     216          double s = rand.NextDouble() < 0.5 ? -1 : 1;
     217          node.Value = s * node.Value * f;
     218        }
     219      }
    99220      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
    100       SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
     221
     222      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, tree, problemData, problemData.TrainingIndices,
    101223        applyLinearScaling: false, maxIterations: maxIterations,
    102224        updateVariableWeights: false, updateConstantsInTree: true);
    103 
    104225
    105226      var scaledModel = new SymbolicRegressionModel(problemData.TargetVariable, tree, (ISymbolicDataAnalysisExpressionTreeInterpreter)interpreter.Clone());
  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Formatters/InfixExpressionFormatter.cs

    r14116 r14564  
    8181          }
    8282          strBuilder.Append(")");
     83        } else {
     84          // function with multiple arguments
     85          strBuilder.Append(token).Append("(");
     86          FormatRecursively(node.Subtrees.First(), strBuilder);
     87          foreach (var subtree in node.Subtrees.Skip(1)) {
     88            strBuilder.Append(", ");
     89            FormatRecursively(subtree, strBuilder);
     90          }
     91          strBuilder.Append(")");
    8392        }
    8493      } else if (node.SubtreeCount == 1) {
     
    94103          FormatRecursively(node.GetSubtree(0), strBuilder);
    95104        } else {
    96           // function
     105          // function with only one argument
    97106          strBuilder.Append(token).Append("(");
    98107          FormatRecursively(node.GetSubtree(0), strBuilder);
  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Importer/InfixExpressionParser.cs

    r14116 r14564  
    3737  /// </summary>
    3838  public sealed class InfixExpressionParser {
    39     private enum TokenType { Operator, Identifier, Number, LeftPar, RightPar, End, NA };
     39    private enum TokenType { Operator, Identifier, Number, LeftPar, RightPar, Comma, End, NA };
    4040    private class Token {
    4141      internal double doubleVal;
     
    102102        { "MEAN", new Average()},
    103103        { "IF", new IfThenElse()},
    104         { ">", new GreaterThan()},
    105         { "<", new LessThan()},
     104        { "GT", new GreaterThan()},
     105        { "LT", new LessThan()},
    106106        { "AND", new And()},
    107107        { "OR", new Or()},
     
    138138        }
    139139        if (char.IsDigit(str[pos])) {
    140           // read number (=> read until white space or operator)
     140          // read number (=> read until white space or operator or comma)
    141141          var sb = new StringBuilder();
    142142          sb.Append(str[pos]);
    143143          pos++;
    144144          while (pos < str.Length && !char.IsWhiteSpace(str[pos])
    145             && (str[pos] != '+' || str[pos-1] == 'e' || str[pos-1] == 'E')     // continue reading exponents
     145            && (str[pos] != '+' || str[pos - 1] == 'e' || str[pos - 1] == 'E')     // continue reading exponents
    146146            && (str[pos] != '-' || str[pos - 1] == 'e' || str[pos - 1] == 'E')
    147             && str[pos] != '*'           
     147            && str[pos] != '*'
    148148            && str[pos] != '/'
    149             && str[pos] != ')') {
     149            && str[pos] != ')'
     150            && str[pos] != ',') {
    150151            sb.Append(str[pos]);
    151152            pos++;
     
    211212          pos++;
    212213          yield return new Token { TokenType = TokenType.RightPar, strVal = ")" };
    213         }
    214       }
    215     }
    216 
    217     // S = Expr EOF
    218     // Expr = ['-' | '+'] Term { '+' Term | '-' Term }
    219     // Term = Fact { '*' Fact | '/' Fact }
    220     // Fact = '(' Expr ')' | funcId '(' Expr ')' | varId | number
     214        } else if (str[pos] == ',') {
     215          pos++;
     216          yield return new Token { TokenType = TokenType.Comma, strVal = "," };
     217        } else {
     218          throw new ArgumentException("Invalid character: " + str[pos]);
     219        }
     220      }
     221    }
     222
     223    // S       = Expr EOF
     224    // Expr    = ['-' | '+'] Term { '+' Term | '-' Term }
     225    // Term    = Fact { '*' Fact | '/' Fact }
     226    // Fact    = '(' Expr ')' | funcId '(' ArgList ')' | varId | number
     227    // ArgList = Expr { ',' Expr }
    221228    private ISymbolicExpressionTreeNode ParseS(Queue<Token> tokens) {
    222229      var expr = ParseExpr(tokens);
     
    326333    }
    327334
    328     // Fact = '(' Expr ')' | funcId '(' Expr ')' | varId | number
     335    // Fact = '(' Expr ')' | funcId '(' ArgList ')' | varId | number
    329336    private ISymbolicExpressionTreeNode ParseFact(Queue<Token> tokens) {
    330337      var next = tokens.Peek();
     
    346353          if (lPar.TokenType != TokenType.LeftPar)
    347354            throw new ArgumentException("expected (");
    348           var expr = ParseExpr(tokens);
     355          var args = ParseArgList(tokens);
     356
     357          // check semantic constraints
     358          if (funcNode.Symbol.MinimumArity > args.Length || funcNode.Symbol.MaximumArity < args.Length)
     359            throw new ArgumentException(string.Format("Symbol {0} requires between {1} and  {2} arguments.", funcId,
     360              funcNode.Symbol.MinimumArity, funcNode.Symbol.MaximumArity));
     361          foreach (var arg in args) funcNode.AddSubtree(arg);
     362
    349363          var rPar = tokens.Dequeue();
    350364          if (rPar.TokenType != TokenType.RightPar)
    351365            throw new ArgumentException("expected )");
    352366
    353           funcNode.AddSubtree(expr);
    354367          return funcNode;
    355368        } else {
     
    369382      }
    370383    }
     384
     385    // ArgList = Expr { ',' Expr }
     386    private ISymbolicExpressionTreeNode[] ParseArgList(Queue<Token> tokens) {
     387      var exprList = new List<ISymbolicExpressionTreeNode>();
     388      exprList.Add(ParseExpr(tokens));
     389      while (tokens.Peek().TokenType != TokenType.RightPar) {
     390        var comma = tokens.Dequeue();
     391        if (comma.TokenType != TokenType.Comma) throw new ArgumentException("expected ',' ");
     392        exprList.Add(ParseExpr(tokens));
     393      }
     394      return exprList.ToArray();
     395    }
    371396  }
    372397}
  • stable/HeuristicLab.Tests

  • stable/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/InfixExpressionParserTest.cs

    r14116 r14564  
    8686      Console.WriteLine(formatter.Format(parser.Parse("x1*x2+x3*x4")));
    8787
     88
     89      Console.WriteLine(formatter.Format(parser.Parse("POW(3, 2)")));
     90      Console.WriteLine(formatter.Format(parser.Parse("POW(3.1, 2.1)")));
     91      Console.WriteLine(formatter.Format(parser.Parse("POW(3.1 , 2.1)")));
     92      Console.WriteLine(formatter.Format(parser.Parse("POW(3.1 ,2.1)")));
     93      Console.WriteLine(formatter.Format(parser.Parse("POW(-3.1 , - 2.1)")));
     94      Console.WriteLine(formatter.Format(parser.Parse("ROOT(3, 2)")));
     95      Console.WriteLine(formatter.Format(parser.Parse("ROOT(3.1, 2.1)")));
     96      Console.WriteLine(formatter.Format(parser.Parse("ROOT(3.1 , 2.1)")));
     97      Console.WriteLine(formatter.Format(parser.Parse("ROOT(3.1 ,2.1)")));
     98      Console.WriteLine(formatter.Format(parser.Parse("ROOT(-3.1 , - 2.1)")));
     99
     100      Console.WriteLine(formatter.Format(parser.Parse("IF(GT( 0, 1), 1, 0)")));
     101      Console.WriteLine(formatter.Format(parser.Parse("IF(LT(0,1), 1 , 0)")));
     102
    88103    }
    89104  }
Note: See TracChangeset for help on using the changeset viewer.