Changeset 14539


Ignore:
Timestamp:
01/04/17 13:58:05 (3 years ago)
Author:
gkronber
Message:

#2650: extended simplifier to pass all new unit tests for factors and binary factors

Location:
branches/symbreg-factors-2650
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Importer/SymbolicExpressionImporter.cs

    r14535 r14539  
    228228      // create a set of (virtual) values to match the number of weights
    229229      t.Symbol.VariableNames = new string[] { t.VariableName };
    230       t.Symbol.VariableValues = new KeyValuePair<string, List<string>>[] { new KeyValuePair<string, List<string>>(t.VariableName, weights.Select((_, i) => "x" + i).ToList()) };
     230      t.Symbol.VariableValues = new KeyValuePair<string, List<string>>[] { new KeyValuePair<string, List<string>>(t.VariableName, weights.Select((_, i) => "X" + i).ToList()) };
    231231      return t;
    232232    }
  • branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeSimplifier.cs

    r14535 r14539  
    659659      } else if(IsBinFactor(node)) {
    660660        var binFactor = node as BinaryFactorVariableTreeNode;
    661         return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Cos(binFactor.Weight));
     661        // cos(0) = 1 see similar case for Exp(binfactor)
     662        return MakeSum(MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Cos(binFactor.Weight) - 1),
     663          MakeConstant(1.0));
    662664      } else {
    663665        var cosNode = cosineSymbol.CreateTreeNode();
     
    675677        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Exp(w)));
    676678      } else if(IsBinFactor(node)) {
     679        // exp( binfactor w val=a) = if(val=a) exp(w) else exp(0) = binfactor( (exp(w) - 1) val a) + 1
    677680        var binFactor = node as BinaryFactorVariableTreeNode;
    678         return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Exp(binFactor.Weight));
     681        return
     682          MakeSum(MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Exp(binFactor.Weight) - 1), MakeConstant(1.0));
    679683      } else if(IsLog(node)) {
    680684        return node.GetSubtree(0);
     
    696700        var factNode = node as FactorVariableTreeNode;
    697701        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Log(w)));
    698       } else if(IsBinFactor(node)) {
    699         var binFactor = node as BinaryFactorVariableTreeNode;
    700         return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Log(binFactor.Weight));
    701702      } else if(IsExp(node)) {
    702703        return node.GetSubtree(0);
     
    881882        var node1 = b as FactorVariableTreeNode;
    882883        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => u / v));
     884      } else if(IsFactor(a) && IsBinFactor(b) && ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
     885        var node0 = a as FactorVariableTreeNode;
     886        var node1 = b as BinaryFactorVariableTreeNode;
     887        var varValues = node0.Symbol.GetVariableValues(node0.VariableName).ToArray();
     888        var wi = Array.IndexOf(varValues, node1.VariableValue);
     889        if(wi < 0) throw new ArgumentException();
     890        var newWeighs = new double[varValues.Length];
     891        node0.Weights.CopyTo(newWeighs, 0);
     892        for(int i = 0; i < newWeighs.Length; i++)
     893          if(wi == i) newWeighs[i] /= node1.Weight;
     894          else newWeighs[i] /= 0.0;
     895        return MakeFactor(node0.Symbol, node0.VariableName, newWeighs);
    883896      } else if(IsFactor(a)) {
    884897        return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
     
    935948        var node1 = b as FactorVariableTreeNode;
    936949        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => u + v));
     950      } else if(IsBinFactor(a) && IsFactor(b)) {
     951        return MakeSum(b, a);
     952      } else if(IsFactor(a) && IsBinFactor(b) &&
     953        ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
     954        var node0 = a as FactorVariableTreeNode;
     955        var node1 = b as BinaryFactorVariableTreeNode;
     956        var varValues = node0.Symbol.GetVariableValues(node0.VariableName).ToArray();
     957        var wi = Array.IndexOf(varValues, node1.VariableValue);
     958        if(wi < 0) throw new ArgumentException();
     959        var newWeighs = new double[varValues.Length];
     960        node0.Weights.CopyTo(newWeighs, 0);
     961        newWeighs[wi] += node1.Weight;
     962        return MakeFactor(node0.Symbol, node0.VariableName, newWeighs);
    937963      } else if(IsAddition(a) && IsAddition(b)) {
    938964        // merge additions
     
    10791105        var node1 = b as ConstantTreeNode;
    10801106        return MakeBinFactor(node0.Symbol, node0.VariableName, node0.VariableValue, node0.Weight * node1.Value);
     1107      } else if(IsBinFactor(a) && IsFactor(b)) {
     1108        return MakeProduct(b, a);
     1109      } else if(IsFactor(a) && IsBinFactor(b) &&
     1110        ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
     1111        var node0 = a as FactorVariableTreeNode;
     1112        var node1 = b as BinaryFactorVariableTreeNode;
     1113        var varValues = node0.Symbol.GetVariableValues(node0.VariableName).ToArray();
     1114        var wi = Array.IndexOf(varValues, node1.VariableValue);
     1115        if(wi < 0) throw new ArgumentException();
     1116        return MakeBinFactor(node1.Symbol, node1.VariableName, node1.VariableValue, node1.Weight * node0.Weights[wi]);
    10811117      } else if(IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
    10821118        // $ * 1.0 => $
  • branches/symbreg-factors-2650/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/SymbolicDataAnalysisExpressionTreeSimplifierTest.cs

    r14535 r14539  
    161161      AssertEqualAfterSimplification("(* 2.0 (factor a 4.0 6.0))", "(factor a 8.0 12.0)");
    162162      AssertEqualAfterSimplification("(* (factor a 4.0 6.0) 2.0)", "(factor a 8.0 12.0)");
    163       AssertEqualAfterSimplification("(* (factor a 4.0 6.0) (variable 2.0 a))", "(* (factor a 8.0 12.0) (variable 1.0 a))"); // not possible (a is used as factor and double variable)
     163      AssertEqualAfterSimplification("(* (factor a 4.0 6.0) (variable 2.0 a))", "(* (factor a 8.0 12.0) (variable 1.0 a))"); // not possible (a is used as factor and double variable) interpreter will fail
    164164      AssertEqualAfterSimplification(
    165165        "(log (factor a 10.0 100.0))",
     
    192192
    193193      AssertEqualAfterSimplification("(+ 3.0 (binfactor a val 4.0 ))", "(+ (binfactor a val 4.0 ) 3.0))"); // not allowed
    194       AssertEqualAfterSimplification("(- 3.0 (binfactor a val 4.0 ))", "(- 3.0 (binfactor a val 4.0 ))"); // not allowed
     194      AssertEqualAfterSimplification("(- 3.0 (binfactor a val 4.0 ))", "(+ (binfactor a val -4.0 ) 3.0)");
    195195      AssertEqualAfterSimplification("(+ (binfactor a val 4.0 ) 3.0)", "(+ (binfactor a val 4.0 ) 3.0)");  // not allowed
    196       AssertEqualAfterSimplification("(- (binfactor a val 4.0 ) 3.0)", "(- (binfactor a val 4.0 ) 3.0)");  // not allowed
     196      AssertEqualAfterSimplification("(- (binfactor a val 4.0 ) 3.0)", "(+ (binfactor a val 4.0 ) -3.0)");
    197197      AssertEqualAfterSimplification("(* 2.0 (binfactor a val 4.0))", "(binfactor a val 8.0 )");
    198198      AssertEqualAfterSimplification("(* (binfactor a val 4.0) 2.0)", "(binfactor a val 8.0 )");
    199       AssertEqualAfterSimplification("(* (binfactor a val 4.0) (variable 2.0 a))", "(* (binfactor a val 8.0) (variable 1.0 a))");  // not possible (a is used as factor and double variable)
    200       AssertEqualAfterSimplification("(log (binfactor a val 10.0))", "(log (binfactor a val 10.0))"); // not allowed
    201       AssertEqualAfterSimplification("(exp (binfactor a val 3.0))", "(exp (binfactor a val 3.0))"); // not allowed
     199      AssertEqualAfterSimplification("(* (binfactor a val 4.0) (variable 2.0 a))", "(* (binfactor a val 1.0) (variable 1.0 a) 8.0)");   // not possible (a is used as factor and double variable) interpreter will fail
     200      AssertEqualAfterSimplification("(log (binfactor a val 10.0))", "(log (binfactor a val 10.0))"); // not allowed (log(0))
     201
     202      // exp( binfactor w val=a) = if(val=a) exp(w) else exp(0) = binfactor( (exp(w) - 1) val a) + 1
     203      AssertEqualAfterSimplification("(exp (binfactor a val 3.0))",
     204        string.Format(CultureInfo.InvariantCulture, "(+ (binfactor a val {0}) 1.0)", Math.Exp(3.0) - 1)
     205        );
    202206      AssertEqualAfterSimplification("(sqrt (binfactor a val 16.0))", "(binfactor a val 4.0))"); // sqrt(0) = 0
    203207      AssertEqualAfterSimplification("(sqr (binfactor a val 3.0))", "(binfactor a val 9.0))"); // 0*0 = 0
     
    207211      AssertEqualAfterSimplification("(sin (binfactor a val 2.0) )",
    208212        string.Format(CultureInfo.InvariantCulture, "(binfactor a val {0}))", Math.Sin(2.0))); // sin(0) = 0
    209       AssertEqualAfterSimplification("(cos (binfactor a val 2.0) )", "(cos (binfactor a val 2.0) )"); // not allowed
     213      AssertEqualAfterSimplification("(cos (binfactor a val 2.0) )",
     214        string.Format(CultureInfo.InvariantCulture, "(+ (binfactor a val {0}) 1.0)", Math.Cos(2.0) - 1)); // cos(0) = 1
    210215      AssertEqualAfterSimplification("(tan (binfactor a val 2.0) )",
    211216        string.Format(CultureInfo.InvariantCulture, "(binfactor a val {0}))", Math.Tan(2.0))); // tan(0) = 0
    212217
    213218      // combination of factor and binfactor
    214       // TODO
    215219      AssertEqualAfterSimplification("(+ (binfactor a x0 2.0) (factor a 2.0 3.0))", "(factor a 4.0 3.0)");
    216       AssertEqualAfterSimplification("(* (binfactor a x1 2.0) (factor a 2.0 3.0))", "(binfactor a x1 4.0)"); // all other values have weight zero in binfactor
     220      AssertEqualAfterSimplification("(+ (factor a 2.0 3.0) (binfactor a x0 2.0))", "(factor a 4.0 3.0)");
     221      AssertEqualAfterSimplification("(* (binfactor a x1 2.0) (factor a 2.0 3.0))", "(binfactor a x1 6.0)"); // all other values have weight zero in binfactor
     222      AssertEqualAfterSimplification("(* (factor a 2.0 3.0) (binfactor a x1 2.0))", "(binfactor a x1 6.0)"); // all other values have weight zero in binfactor
    217223      AssertEqualAfterSimplification("(/ (binfactor a x0 2.0) (factor a 2.0 3.0))", "(binfactor a x0 1.0)");
    218224      AssertEqualAfterSimplification("(/ (factor a 2.0 3.0) (binfactor a x0 2.0))",
Note: See TracChangeset for help on using the changeset viewer.