Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/04/17 17:52:44 (8 years ago)
Author:
gkronber
Message:

#2650: merged the factors branch into trunk

Location:
trunk/sources
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r14400 r14826  
    176176
    177177
    178     public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
    179 
    180       List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
    181       List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
    182       List<string> variableNames = new List<string>();
    183       List<int> lags = new List<int>();
     178
     179    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     180      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
     181      int maxIterations, bool updateVariableWeights = true,
     182      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     183      bool updateConstantsInTree = true) {
     184
     185      // numeric constants in the tree become variables for constant opt
     186      // variables in the tree become parameters (fixed values) for constant opt
     187      // for each parameter (variable in the original tree) we store the
     188      // variable name, variable value (for factor vars) and lag as a DataForVariable object.
     189      // A dictionary is used to find parameters
     190      var variables = new List<AutoDiff.Variable>();
     191      var parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
    184192
    185193      AutoDiff.Term func;
    186       if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out func))
     194      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, updateVariableWeights, out func))
    187195        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    188       if (variableNames.Count == 0) return 0.0;
    189 
    190       AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());
    191 
    192       List<SymbolicExpressionTreeTerminalNode> terminalNodes = null;
     196      if (parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
     197
     198      var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
     199      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
     200
     201      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants
    193202      if (updateVariableWeights)
    194203        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
    195204      else
    196         terminalNodes = new List<SymbolicExpressionTreeTerminalNode>(tree.Root.IterateNodesPrefix().OfType<ConstantTreeNode>());
     205        terminalNodes = new List<SymbolicExpressionTreeTerminalNode>
     206          (tree.Root.IterateNodesPrefix()
     207          .OfType<SymbolicExpressionTreeTerminalNode>()
     208          .Where(node => node is ConstantTreeNode || node is FactorVariableTreeNode));
    197209
    198210      //extract inital constants
     
    205217          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    206218          VariableTreeNode variableTreeNode = node as VariableTreeNode;
     219          BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
     220          FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    207221          if (constantTreeNode != null)
    208222            c[i++] = constantTreeNode.Value;
    209223          else if (updateVariableWeights && variableTreeNode != null)
    210224            c[i++] = variableTreeNode.Weight;
     225          else if (updateVariableWeights && binFactorVarTreeNode != null)
     226            c[i++] = binFactorVarTreeNode.Weight;
     227          else if (factorVarTreeNode != null) {
     228            // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
     229            foreach (var w in factorVarTreeNode.Weights) c[i++] = w;
     230          }
    211231        }
    212232      }
     
    216236      alglib.lsfitstate state;
    217237      alglib.lsfitreport rep;
    218       int info;
     238      int retVal;
    219239
    220240      IDataset ds = problemData.Dataset;
    221       double[,] x = new double[rows.Count(), variableNames.Count];
     241      double[,] x = new double[rows.Count(), parameters.Count];
    222242      int row = 0;
    223243      foreach (var r in rows) {
    224         for (int col = 0; col < variableNames.Count; col++) {
    225           int lag = lags[col];
    226           x[row, col] = ds.GetDoubleValue(variableNames[col], r + lag);
     244        int col = 0;
     245        foreach (var kvp in parameterEntries) {
     246          var info = kvp.Key;
     247          int lag = info.lag;
     248          if (ds.VariableHasType<double>(info.variableName)) {
     249            x[row, col] = ds.GetDoubleValue(info.variableName, r + lag);
     250          } else if (ds.VariableHasType<string>(info.variableName)) {
     251            x[row, col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0;
     252          } else throw new InvalidProgramException("found a variable of unknown type");
     253          col++;
    227254        }
    228255        row++;
     
    241268        //alglib.lsfitsetgradientcheck(state, 0.001);
    242269        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
    243         alglib.lsfitresults(state, out info, out c, out rep);
     270        alglib.lsfitresults(state, out retVal, out c, out rep);
    244271      } catch (ArithmeticException) {
    245272        return originalQuality;
     
    248275      }
    249276
    250       //info == -7  => constant optimization failed due to wrong gradient
    251       if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
     277      //retVal == -7  => constant optimization failed due to wrong gradient
     278      if (retVal != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
    252279      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
    253280
     
    265292        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    266293        VariableTreeNode variableTreeNode = node as VariableTreeNode;
     294        BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
     295        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    267296        if (constantTreeNode != null)
    268297          constantTreeNode.Value = constants[i++];
    269298        else if (updateVariableWeights && variableTreeNode != null)
    270299          variableTreeNode.Weight = constants[i++];
     300        else if (updateVariableWeights && binFactorVarTreeNode != null)
     301          binFactorVarTreeNode.Weight = constants[i++];
     302        else if (factorVarTreeNode != null) {
     303          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
     304            factorVarTreeNode.Weights[j] = constants[i++];
     305        }
    271306      }
    272307    }
     
    286321    }
    287322
    288     private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, List<int> lags, bool updateVariableWeights, out AutoDiff.Term term) {
     323    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node,
     324      List<AutoDiff.Variable> variables, Dictionary<DataForVariable, AutoDiff.Variable> parameters,
     325      bool updateVariableWeights, out AutoDiff.Term term) {
    289326      if (node.Symbol is Constant) {
    290327        var var = new AutoDiff.Variable();
     
    293330        return true;
    294331      }
    295       if (node.Symbol is Variable) {
    296         var varNode = node as VariableTreeNode;
    297         var par = new AutoDiff.Variable();
    298         parameters.Add(par);
    299         variableNames.Add(varNode.VariableName);
    300         lags.Add(0);
     332      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
     333        var varNode = node as VariableTreeNodeBase;
     334        var factorVarNode = node as BinaryFactorVariableTreeNode;
     335        // factor variable values are only 0 or 1 and set in x accordingly
     336        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
     337        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
    301338
    302339        if (updateVariableWeights) {
     
    309346        return true;
    310347      }
     348      if (node.Symbol is FactorVariable) {
     349        var factorVarNode = node as FactorVariableTreeNode;
     350        var products = new List<Term>();
     351        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
     352          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
     353
     354          var wVar = new AutoDiff.Variable();
     355          variables.Add(wVar);
     356
     357          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
     358        }
     359        term = AutoDiff.TermBuilder.Sum(products);
     360        return true;
     361      }
    311362      if (node.Symbol is LaggedVariable) {
    312363        var varNode = node as LaggedVariableTreeNode;
    313         var par = new AutoDiff.Variable();
    314         parameters.Add(par);
    315         variableNames.Add(varNode.VariableName);
    316         lags.Add(varNode.Lag);
     364        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
    317365
    318366        if (updateVariableWeights) {
     
    329377        foreach (var subTree in node.Subtrees) {
    330378          AutoDiff.Term t;
    331           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     379          if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    332380            term = null;
    333381            return false;
     
    342390        for (int i = 0; i < node.SubtreeCount; i++) {
    343391          AutoDiff.Term t;
    344           if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     392          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, updateVariableWeights, out t)) {
    345393            term = null;
    346394            return false;
     
    357405        foreach (var subTree in node.Subtrees) {
    358406          AutoDiff.Term t;
    359           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     407          if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    360408            term = null;
    361409            return false;
     
    372420        foreach (var subTree in node.Subtrees) {
    373421          AutoDiff.Term t;
    374           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     422          if (!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    375423            term = null;
    376424            return false;
     
    384432      if (node.Symbol is Logarithm) {
    385433        AutoDiff.Term t;
    386         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     434        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    387435          term = null;
    388436          return false;
     
    394442      if (node.Symbol is Exponential) {
    395443        AutoDiff.Term t;
    396         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     444        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    397445          term = null;
    398446          return false;
     
    404452      if (node.Symbol is Square) {
    405453        AutoDiff.Term t;
    406         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     454        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    407455          term = null;
    408456          return false;
     
    414462      if (node.Symbol is SquareRoot) {
    415463        AutoDiff.Term t;
    416         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     464        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    417465          term = null;
    418466          return false;
     
    424472      if (node.Symbol is Sine) {
    425473        AutoDiff.Term t;
    426         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     474        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    427475          term = null;
    428476          return false;
     
    434482      if (node.Symbol is Cosine) {
    435483        AutoDiff.Term t;
    436         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     484        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    437485          term = null;
    438486          return false;
     
    444492      if (node.Symbol is Tangent) {
    445493        AutoDiff.Term t;
    446         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     494        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    447495          term = null;
    448496          return false;
     
    454502      if (node.Symbol is Erf) {
    455503        AutoDiff.Term t;
    456         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     504        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    457505          term = null;
    458506          return false;
     
    464512      if (node.Symbol is Norm) {
    465513        AutoDiff.Term t;
    466         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out t)) {
     514        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    467515          term = null;
    468516          return false;
     
    478526        variables.Add(alpha);
    479527        AutoDiff.Term branchTerm;
    480         if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, updateVariableWeights, out branchTerm)) {
     528        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out branchTerm)) {
    481529          term = branchTerm * alpha + beta;
    482530          return true;
     
    488536      term = null;
    489537      return false;
     538    }
     539
     540    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
     541    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
     542    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
     543      string varName, string varValue = "", int lag = 0) {
     544      var data = new DataForVariable(varName, varValue, lag);
     545
     546      AutoDiff.Variable par = null;
     547      if (!parameters.TryGetValue(data, out par)) {
     548        // not found -> create new parameter and entries in names and values lists
     549        par = new AutoDiff.Variable();
     550        parameters.Add(data, par);
     551      }
     552      return par;
    490553    }
    491554
     
    495558        where
    496559         !(n.Symbol is Variable) &&
     560         !(n.Symbol is BinaryFactorVariable) &&
     561         !(n.Symbol is FactorVariable) &&
    497562         !(n.Symbol is LaggedVariable) &&
    498563         !(n.Symbol is Constant) &&
     
    515580      return !containsUnknownSymbol;
    516581    }
     582
     583
     584    #region helper class
     585    private class DataForVariable {
     586      public readonly string variableName;
     587      public readonly string variableValue; // for factor vars
     588      public readonly int lag;
     589
     590      public DataForVariable(string varName, string varValue, int lag) {
     591        this.variableName = varName;
     592        this.variableValue = varValue;
     593        this.lag = lag;
     594      }
     595
     596      public override bool Equals(object obj) {
     597        var other = obj as DataForVariable;
     598        if (other == null) return false;
     599        return other.variableName.Equals(this.variableName) &&
     600               other.variableValue.Equals(this.variableValue) &&
     601               other.lag == this.lag;
     602      }
     603
     604      public override int GetHashCode() {
     605        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
     606      }
     607    }
     608    #endregion
    517609  }
    518610}
Note: See TracChangeset for help on using the changeset viewer.