Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/05/16 21:34:18 (8 years ago)
Author:
mkommend
Message:

#2584: Merged r13670 and r13916 into stable.
#2609: Merged r13869 and r13900 into stable.

Location:
stable
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r13310 r14004  
    4040    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    4141    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
     42    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";
    4243
    4344    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
     
    5657      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    5758    }
     59    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
     60      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
     61    }
     62
    5863
    5964    public IntValue ConstantOptimizationIterations {
     
    7277      get { return UpdateConstantsInTreeParameter.Value.Value; }
    7378      set { UpdateConstantsInTreeParameter.Value.Value = value; }
     79    }
     80
     81    public bool UpdateVariableWeights {
     82      get { return UpdateVariableWeightsParameter.Value.Value; }
     83      set { UpdateVariableWeightsParameter.Value.Value = value; }
    7484    }
    7585
     
    8696      : base() {
    8797      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
    88       Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
     98      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true) { Hidden = true });
    8999      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
    90100      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
    91       Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
     101      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
     102      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });
    92103    }
    93104
     
    100111      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
    101112        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
     113      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
     114        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)));
    102115    }
    103116
     
    108121        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
    109122        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
    110            constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
    111            EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree);
     123           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree);
    112124
    113125        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
     
    164176
    165177
    166     // TODO: swap positions of lowerEstimationLimit and upperEstimationLimit parameters
    167     public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
    168       IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {
     178    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
    169179
    170180      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
     
    173183
    174184      AutoDiff.Term func;
    175       if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
     185      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out func))
    176186        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    177187      if (variableNames.Count == 0) return 0.0;
    178188
    179       AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());
    180 
    181       List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
     189      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());
     190
     191      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null;
     192      if (updateVariableWeights)
     193        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
     194      else
     195        terminalNodes = new List<SymbolicExpressionTreeTerminalNode>(tree.Root.IterateNodesPrefix().OfType<ConstantTreeNode>());
     196
     197      //extract inital constants
    182198      double[] c = new double[variables.Count];
    183 
    184199      {
    185200        c[0] = 0.0;
    186201        c[1] = 1.0;
    187         //extract inital constants
    188202        int i = 2;
    189203        foreach (var node in terminalNodes) {
     
    192206          if (constantTreeNode != null)
    193207            c[i++] = constantTreeNode.Value;
    194           else if (variableTreeNode != null)
     208          else if (updateVariableWeights && variableTreeNode != null)
    195209            c[i++] = variableTreeNode.Weight;
    196210        }
     
    235249
    236250      //info == -7  => constant optimization failed due to wrong gradient
    237       if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
     251      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
    238252      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
    239253
    240       if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
     254      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
    241255      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
    242         UpdateConstants(tree, originalConstants.Skip(2).ToArray());
     256        UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
    243257        return originalQuality;
    244258      }
     
    246260    }
    247261
    248     private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
     262    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
    249263      int i = 0;
    250264      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
     
    253267        if (constantTreeNode != null)
    254268          constantTreeNode.Value = constants[i++];
    255         else if (variableTreeNode != null)
     269        else if (updateVariableWeights && variableTreeNode != null)
    256270          variableTreeNode.Weight = constants[i++];
    257271      }
     
    272286    }
    273287
    274     private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
     288    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, bool updateVariableWeights, out AutoDiff.Term term) {
    275289      if (node.Symbol is Constant) {
    276290        var var = new AutoDiff.Variable();
     
    284298        parameters.Add(par);
    285299        variableNames.Add(varNode.VariableName);
    286         var w = new AutoDiff.Variable();
    287         variables.Add(w);
    288         term = AutoDiff.TermBuilder.Product(w, par);
     300
     301        if (updateVariableWeights) {
     302          var w = new AutoDiff.Variable();
     303          variables.Add(w);
     304          term = AutoDiff.TermBuilder.Product(w, par);
     305        } else {
     306          term = par;
     307        }
    289308        return true;
    290309      }
     
    293312        foreach (var subTree in node.Subtrees) {
    294313          AutoDiff.Term t;
    295           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
     314          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, updateVariableWeights, out t)) {
    296315            term = null;
    297316            return false;
     
    306325        for (int i = 0; i < node.SubtreeCount; i++) {
    307326          AutoDiff.Term t;
    308           if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
     327          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, updateVariableWeights, out t)) {
    309328            term = null;
    310329            return false;
     
    317336      }
    318337      if (node.Symbol is Multiplication) {
    319         AutoDiff.Term a, b;
    320         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
    321           !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
    322           term = null;
    323           return false;
    324         } else {
    325           List<AutoDiff.Term> factors = new List<Term>();
    326           foreach (var subTree in node.Subtrees.Skip(2)) {
    327             AutoDiff.Term f;
    328             if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
    329               term = null;
    330               return false;
    331             }
    332             factors.Add(f);
     338        List<AutoDiff.Term> terms = new List<Term>();
     339        foreach (var subTree in node.Subtrees) {
     340          AutoDiff.Term t;
     341          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, updateVariableWeights, out t)) {
     342            term = null;
     343            return false;
    333344          }
    334           term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
    335           return true;
    336         }
     345          terms.Add(t);
     346        }
     347        if (terms.Count == 1) term = terms[0];
     348        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
     349        return true;
     350
    337351      }
    338352      if (node.Symbol is Division) {
    339         // only works for at least two subtrees
    340         AutoDiff.Term a, b;
    341         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
    342           !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
    343           term = null;
    344           return false;
    345         } else {
    346           List<AutoDiff.Term> factors = new List<Term>();
    347           foreach (var subTree in node.Subtrees.Skip(2)) {
    348             AutoDiff.Term f;
    349             if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
    350               term = null;
    351               return false;
    352             }
    353             factors.Add(1.0 / f);
     353        List<AutoDiff.Term> terms = new List<Term>();
     354        foreach (var subTree in node.Subtrees) {
     355          AutoDiff.Term t;
     356          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, updateVariableWeights, out t)) {
     357            term = null;
     358            return false;
    354359          }
    355           term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
    356           return true;
    357         }
     360          terms.Add(t);
     361        }
     362        if (terms.Count == 1) term = 1.0 / terms[0];
     363        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
     364        return true;
    358365      }
    359366      if (node.Symbol is Logarithm) {
    360367        AutoDiff.Term t;
    361         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     368        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    362369          term = null;
    363370          return false;
     
    369376      if (node.Symbol is Exponential) {
    370377        AutoDiff.Term t;
    371         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     378        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    372379          term = null;
    373380          return false;
     
    379386      if (node.Symbol is Square) {
    380387        AutoDiff.Term t;
    381         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     388        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    382389          term = null;
    383390          return false;
     
    386393          return true;
    387394        }
    388       } if (node.Symbol is SquareRoot) {
    389         AutoDiff.Term t;
    390         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     395      }
     396      if (node.Symbol is SquareRoot) {
     397        AutoDiff.Term t;
     398        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    391399          term = null;
    392400          return false;
     
    395403          return true;
    396404        }
    397       } if (node.Symbol is Sine) {
    398         AutoDiff.Term t;
    399         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     405      }
     406      if (node.Symbol is Sine) {
     407        AutoDiff.Term t;
     408        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    400409          term = null;
    401410          return false;
     
    404413          return true;
    405414        }
    406       } if (node.Symbol is Cosine) {
    407         AutoDiff.Term t;
    408         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     415      }
     416      if (node.Symbol is Cosine) {
     417        AutoDiff.Term t;
     418        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    409419          term = null;
    410420          return false;
     
    413423          return true;
    414424        }
    415       } if (node.Symbol is Tangent) {
    416         AutoDiff.Term t;
    417         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     425      }
     426      if (node.Symbol is Tangent) {
     427        AutoDiff.Term t;
     428        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    418429          term = null;
    419430          return false;
     
    422433          return true;
    423434        }
    424       } if (node.Symbol is Erf) {
    425         AutoDiff.Term t;
    426         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     435      }
     436      if (node.Symbol is Erf) {
     437        AutoDiff.Term t;
     438        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    427439          term = null;
    428440          return false;
     
    431443          return true;
    432444        }
    433       } if (node.Symbol is Norm) {
    434         AutoDiff.Term t;
    435         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
     445      }
     446      if (node.Symbol is Norm) {
     447        AutoDiff.Term t;
     448        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
    436449          term = null;
    437450          return false;
     
    447460        variables.Add(alpha);
    448461        AutoDiff.Term branchTerm;
    449         if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
     462        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out branchTerm)) {
    450463          term = branchTerm * alpha + beta;
    451464          return true;
Note: See TracChangeset for help on using the changeset viewer.