Ignore:
Timestamp:
04/04/17 17:52:44 (6 months ago)
Author:
gkronber
Message:

#2650: merged the factors branch into trunk

Location:
trunk/sources
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic

  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeSimplifier.cs

    r14400 r14826  
    11#region License Information
     2
    23/* HeuristicLab
    34 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     
    1819 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
    1920 */
     21
    2022#endregion
    2123
    2224using System;
    2325using System.Collections.Generic;
     26using System.Diagnostics;
    2427using System.Linq;
    2528using HeuristicLab.Common;
     29using HeuristicLab.Core;
    2630using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2731
     
    3236  public class SymbolicDataAnalysisExpressionTreeSimplifier {
    3337    private Addition addSymbol = new Addition();
    34     private Subtraction subSymbol = new Subtraction();
    3538    private Multiplication mulSymbol = new Multiplication();
    3639    private Division divSymbol = new Division();
     
    6265      ISymbolicExpressionTreeNode rootNode = (new ProgramRootSymbol()).CreateTreeNode();
    6366      rootNode.AddSubtree(GetSimplifiedTree(macroExpandedTree));
     67
     68#if DEBUG
     69      // check that each node is only referenced once
     70      var nodes = rootNode.IterateNodesPrefix().ToArray();
     71      foreach(var n in nodes) if(nodes.Count(ni => ni == n) > 1) throw new InvalidOperationException();
     72#endif
    6473      return new SymbolicExpressionTree(rootNode);
    6574    }
    6675
    6776    // the argumentTrees list contains already expanded trees used as arguments for invocations
    68     private ISymbolicExpressionTreeNode MacroExpand(ISymbolicExpressionTreeNode root, ISymbolicExpressionTreeNode node, IList<ISymbolicExpressionTreeNode> argumentTrees) {
     77    private ISymbolicExpressionTreeNode MacroExpand(ISymbolicExpressionTreeNode root, ISymbolicExpressionTreeNode node,
     78      IList<ISymbolicExpressionTreeNode> argumentTrees) {
    6979      List<ISymbolicExpressionTreeNode> subtrees = new List<ISymbolicExpressionTreeNode>(node.Subtrees);
    7080      while (node.SubtreeCount > 0) node.RemoveSubtree(0);
     
    98108    }
    99109
    100 
    101110    #region symbol predicates
     111
    102112    // arithmetic
    103113    private bool IsDivision(ISymbolicExpressionTreeNode node) {
     
    120130      return node.Symbol is Average;
    121131    }
     132
    122133    // exponential
    123134    private bool IsLog(ISymbolicExpressionTreeNode node) {
    124135      return node.Symbol is Logarithm;
    125136    }
     137
    126138    private bool IsExp(ISymbolicExpressionTreeNode node) {
    127139      return node.Symbol is Exponential;
    128140    }
     141
    129142    private bool IsRoot(ISymbolicExpressionTreeNode node) {
    130143      return node.Symbol is Root;
    131144    }
     145
    132146    private bool IsSquare(ISymbolicExpressionTreeNode node) {
    133147      return node.Symbol is Square;
    134148    }
     149
    135150    private bool IsSquareRoot(ISymbolicExpressionTreeNode node) {
    136151      return node.Symbol is SquareRoot;
    137152    }
     153
    138154    private bool IsPower(ISymbolicExpressionTreeNode node) {
    139155      return node.Symbol is Power;
    140156    }
     157
    141158    // trigonometric
    142159    private bool IsSine(ISymbolicExpressionTreeNode node) {
    143160      return node.Symbol is Sine;
    144161    }
     162
    145163    private bool IsCosine(ISymbolicExpressionTreeNode node) {
    146164      return node.Symbol is Cosine;
    147165    }
     166
    148167    private bool IsTangent(ISymbolicExpressionTreeNode node) {
    149168      return node.Symbol is Tangent;
    150169    }
     170
    151171    // boolean
    152172    private bool IsIfThenElse(ISymbolicExpressionTreeNode node) {
    153173      return node.Symbol is IfThenElse;
    154174    }
     175
    155176    private bool IsAnd(ISymbolicExpressionTreeNode node) {
    156177      return node.Symbol is And;
    157178    }
     179
    158180    private bool IsOr(ISymbolicExpressionTreeNode node) {
    159181      return node.Symbol is Or;
    160182    }
     183
    161184    private bool IsNot(ISymbolicExpressionTreeNode node) {
    162185      return node.Symbol is Not;
    163186    }
     187
    164188    // comparison
    165189    private bool IsGreaterThan(ISymbolicExpressionTreeNode node) {
    166190      return node.Symbol is GreaterThan;
    167191    }
     192
    168193    private bool IsLessThan(ISymbolicExpressionTreeNode node) {
    169194      return node.Symbol is LessThan;
     
    183208    }
    184209
     210    private bool IsVariableBase(ISymbolicExpressionTreeNode node) {
     211      return node is VariableTreeNodeBase;
     212    }
     213
     214    private bool IsFactor(ISymbolicExpressionTreeNode node) {
     215      return node is FactorVariableTreeNode;
     216    }
     217
     218    private bool IsBinFactor(ISymbolicExpressionTreeNode node) {
     219      return node is BinaryFactorVariableTreeNode;
     220    }
     221
    185222    private bool IsConstant(ISymbolicExpressionTreeNode node) {
    186223      return node.Symbol is Constant;
     
    191228      return node.Symbol is TimeLag;
    192229    }
     230
    193231    private bool IsIntegral(ISymbolicExpressionTreeNode node) {
    194232      return node.Symbol is Integral;
     
    203241    /// <returns></returns>
    204242    public ISymbolicExpressionTreeNode GetSimplifiedTree(ISymbolicExpressionTreeNode original) {
    205       if (IsConstant(original) || IsVariable(original)) {
     243      if (IsConstant(original) || IsVariableBase(original)) {
    206244        return (ISymbolicExpressionTreeNode)original.Clone();
    207245      } else if (IsAddition(original)) {
     
    254292    }
    255293
    256 
    257294    #region specific simplification routines
     295
    258296    private ISymbolicExpressionTreeNode SimplifyAny(ISymbolicExpressionTreeNode original) {
    259297      // can't simplify this function but simplify all subtrees
     
    303341        var remaining = original.Subtrees.Skip(2);
    304342        return
    305           MakeProduct(GetSimplifiedTree(first), Invert(remaining.Aggregate(GetSimplifiedTree(second), (a, b) => MakeProduct(a, GetSimplifiedTree(b)))));
     343          MakeProduct(GetSimplifiedTree(first),
     344            Invert(remaining.Aggregate(GetSimplifiedTree(second), (a, b) => MakeProduct(a, GetSimplifiedTree(b)))));
    306345      }
    307346    }
     
    344383      return MakeNot(GetSimplifiedTree(original.GetSubtree(0)));
    345384    }
     385
    346386    private ISymbolicExpressionTreeNode SimplifyOr(ISymbolicExpressionTreeNode original) {
    347387      return original.Subtrees
     
    349389        .Aggregate(MakeOr);
    350390    }
     391
    351392    private ISymbolicExpressionTreeNode SimplifyAnd(ISymbolicExpressionTreeNode original) {
    352393      return original.Subtrees
     
    354395        .Aggregate(MakeAnd);
    355396    }
     397
    356398    private ISymbolicExpressionTreeNode SimplifyLessThan(ISymbolicExpressionTreeNode original) {
    357399      return MakeLessThan(GetSimplifiedTree(original.GetSubtree(0)), GetSimplifiedTree(original.GetSubtree(1)));
    358400    }
     401
    359402    private ISymbolicExpressionTreeNode SimplifyGreaterThan(ISymbolicExpressionTreeNode original) {
    360403      return MakeGreaterThan(GetSimplifiedTree(original.GetSubtree(0)), GetSimplifiedTree(original.GetSubtree(1)));
    361404    }
     405
    362406    private ISymbolicExpressionTreeNode SimplifyIfThenElse(ISymbolicExpressionTreeNode original) {
    363       return MakeIfThenElse(GetSimplifiedTree(original.GetSubtree(0)), GetSimplifiedTree(original.GetSubtree(1)), GetSimplifiedTree(original.GetSubtree(2)));
    364     }
     407      return MakeIfThenElse(GetSimplifiedTree(original.GetSubtree(0)), GetSimplifiedTree(original.GetSubtree(1)),
     408        GetSimplifiedTree(original.GetSubtree(2)));
     409    }
     410
    365411    private ISymbolicExpressionTreeNode SimplifyTangent(ISymbolicExpressionTreeNode original) {
    366412      return MakeTangent(GetSimplifiedTree(original.GetSubtree(0)));
    367413    }
     414
    368415    private ISymbolicExpressionTreeNode SimplifyCosine(ISymbolicExpressionTreeNode original) {
    369416      return MakeCosine(GetSimplifiedTree(original.GetSubtree(0)));
    370417    }
     418
    371419    private ISymbolicExpressionTreeNode SimplifySine(ISymbolicExpressionTreeNode original) {
    372420      return MakeSine(GetSimplifiedTree(original.GetSubtree(0)));
    373421    }
     422
    374423    private ISymbolicExpressionTreeNode SimplifyExp(ISymbolicExpressionTreeNode original) {
    375424      return MakeExp(GetSimplifiedTree(original.GetSubtree(0)));
    376425    }
     426
    377427    private ISymbolicExpressionTreeNode SimplifySquare(ISymbolicExpressionTreeNode original) {
    378428      return MakeSquare(GetSimplifiedTree(original.GetSubtree(0)));
    379429    }
     430
    380431    private ISymbolicExpressionTreeNode SimplifySquareRoot(ISymbolicExpressionTreeNode original) {
    381432      return MakeSquareRoot(GetSimplifiedTree(original.GetSubtree(0)));
     
    385436      return MakeLog(GetSimplifiedTree(original.GetSubtree(0)));
    386437    }
     438
    387439    private ISymbolicExpressionTreeNode SimplifyRoot(ISymbolicExpressionTreeNode original) {
    388440      return MakeRoot(GetSimplifiedTree(original.GetSubtree(0)), GetSimplifiedTree(original.GetSubtree(1)));
     
    392444      return MakePower(GetSimplifiedTree(original.GetSubtree(0)), GetSimplifiedTree(original.GetSubtree(1)));
    393445    }
     446
    394447    private ISymbolicExpressionTreeNode SimplifyTimeLag(ISymbolicExpressionTreeNode original) {
    395448      var laggedTreeNode = original as ILaggedTreeNode;
     
    401454      }
    402455    }
     456
    403457    private ISymbolicExpressionTreeNode SimplifyIntegral(ISymbolicExpressionTreeNode original) {
    404458      var laggedTreeNode = original as ILaggedTreeNode;
     
    414468
    415469    #region low level tree restructuring
     470
    416471    private ISymbolicExpressionTreeNode MakeTimeLag(ISymbolicExpressionTreeNode subtree, int lag) {
    417472      if (lag == 0) return subtree;
     
    444499      } else if (!IsBoolean(t)) {
    445500        var gtNode = gtSymbol.CreateTreeNode();
    446         gtNode.AddSubtree(t); gtNode.AddSubtree(MakeConstant(0.0));
     501        gtNode.AddSubtree(t);
     502        gtNode.AddSubtree(MakeConstant(0.0));
    447503        var notNode = notSymbol.CreateTreeNode();
    448504        notNode.AddSubtree(gtNode);
     
    484540      }
    485541    }
     542
    486543    private ISymbolicExpressionTreeNode MakeAnd(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
    487544      if (IsConstant(a) && IsConstant(b)) {
     
    513570      }
    514571    }
    515     private ISymbolicExpressionTreeNode MakeLessThan(ISymbolicExpressionTreeNode leftSide, ISymbolicExpressionTreeNode rightSide) {
     572
     573    private ISymbolicExpressionTreeNode MakeLessThan(ISymbolicExpressionTreeNode leftSide,
     574      ISymbolicExpressionTreeNode rightSide) {
    516575      if (IsConstant(leftSide) && IsConstant(rightSide)) {
    517576        var lsConst = leftSide as ConstantTreeNode;
     
    526585      }
    527586    }
    528     private ISymbolicExpressionTreeNode MakeGreaterThan(ISymbolicExpressionTreeNode leftSide, ISymbolicExpressionTreeNode rightSide) {
     587
     588    private ISymbolicExpressionTreeNode MakeGreaterThan(ISymbolicExpressionTreeNode leftSide,
     589      ISymbolicExpressionTreeNode rightSide) {
    529590      if (IsConstant(leftSide) && IsConstant(rightSide)) {
    530591        var lsConst = leftSide as ConstantTreeNode;
     
    539600      }
    540601    }
    541     private ISymbolicExpressionTreeNode MakeIfThenElse(ISymbolicExpressionTreeNode condition, ISymbolicExpressionTreeNode trueBranch, ISymbolicExpressionTreeNode falseBranch) {
     602
     603    private ISymbolicExpressionTreeNode MakeIfThenElse(ISymbolicExpressionTreeNode condition,
     604      ISymbolicExpressionTreeNode trueBranch, ISymbolicExpressionTreeNode falseBranch) {
    542605      if (IsConstant(condition)) {
    543606        var constT = condition as ConstantTreeNode;
     
    550613        } else {
    551614          var gtNode = gtSymbol.CreateTreeNode();
    552           gtNode.AddSubtree(condition); gtNode.AddSubtree(MakeConstant(0.0));
     615          gtNode.AddSubtree(condition);
     616          gtNode.AddSubtree(MakeConstant(0.0));
    553617          ifNode.AddSubtree(gtNode);
    554618        }
     
    563627        var constT = node as ConstantTreeNode;
    564628        return MakeConstant(Math.Sin(constT.Value));
     629      } else if (IsFactor(node)) {
     630        var factor = node as FactorVariableTreeNode;
     631        return MakeFactor(factor.Symbol, factor.VariableName, factor.Weights.Select(Math.Sin));
     632      } else if (IsBinFactor(node)) {
     633        var binFactor = node as BinaryFactorVariableTreeNode;
     634        return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Sin(binFactor.Weight));
    565635      } else {
    566636        var sineNode = sineSymbol.CreateTreeNode();
     
    569639      }
    570640    }
     641
    571642    private ISymbolicExpressionTreeNode MakeTangent(ISymbolicExpressionTreeNode node) {
    572643      if (IsConstant(node)) {
    573644        var constT = node as ConstantTreeNode;
    574645        return MakeConstant(Math.Tan(constT.Value));
     646      } else if (IsFactor(node)) {
     647        var factor = node as FactorVariableTreeNode;
     648        return MakeFactor(factor.Symbol, factor.VariableName, factor.Weights.Select(Math.Tan));
     649      } else if (IsBinFactor(node)) {
     650        var binFactor = node as BinaryFactorVariableTreeNode;
     651        return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Tan(binFactor.Weight));
    575652      } else {
    576653        var tanNode = tanSymbol.CreateTreeNode();
     
    579656      }
    580657    }
     658
    581659    private ISymbolicExpressionTreeNode MakeCosine(ISymbolicExpressionTreeNode node) {
    582660      if (IsConstant(node)) {
    583661        var constT = node as ConstantTreeNode;
    584662        return MakeConstant(Math.Cos(constT.Value));
     663      } else if (IsFactor(node)) {
     664        var factor = node as FactorVariableTreeNode;
     665        return MakeFactor(factor.Symbol, factor.VariableName, factor.Weights.Select(Math.Cos));
     666      } else if (IsBinFactor(node)) {
     667        var binFactor = node as BinaryFactorVariableTreeNode;
     668        // cos(0) = 1 see similar case for Exp(binfactor)
     669        return MakeSum(MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Cos(binFactor.Weight) - 1),
     670          MakeConstant(1.0));
    585671      } else {
    586672        var cosNode = cosineSymbol.CreateTreeNode();
     
    589675      }
    590676    }
     677
    591678    private ISymbolicExpressionTreeNode MakeExp(ISymbolicExpressionTreeNode node) {
    592679      if (IsConstant(node)) {
    593680        var constT = node as ConstantTreeNode;
    594681        return MakeConstant(Math.Exp(constT.Value));
     682      } else if (IsFactor(node)) {
     683        var factNode = node as FactorVariableTreeNode;
     684        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Exp(w)));
     685      } else if (IsBinFactor(node)) {
     686        // exp( binfactor w val=a) = if(val=a) exp(w) else exp(0) = binfactor( (exp(w) - 1) val a) + 1
     687        var binFactor = node as BinaryFactorVariableTreeNode;
     688        return
     689          MakeSum(MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Exp(binFactor.Weight) - 1), MakeConstant(1.0));
    595690      } else if (IsLog(node)) {
    596691        return node.GetSubtree(0);
     
    605700      }
    606701    }
     702    private ISymbolicExpressionTreeNode MakeLog(ISymbolicExpressionTreeNode node) {
     703      if (IsConstant(node)) {
     704        var constT = node as ConstantTreeNode;
     705        return MakeConstant(Math.Log(constT.Value));
     706      } else if (IsFactor(node)) {
     707        var factNode = node as FactorVariableTreeNode;
     708        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Log(w)));
     709      } else if (IsExp(node)) {
     710        return node.GetSubtree(0);
     711      } else if (IsSquareRoot(node)) {
     712        return MakeFraction(MakeLog(node.GetSubtree(0)), MakeConstant(2.0));
     713      } else {
     714        var logNode = logSymbol.CreateTreeNode();
     715        logNode.AddSubtree(node);
     716        return logNode;
     717      }
     718    }
    607719
    608720    private ISymbolicExpressionTreeNode MakeSquare(ISymbolicExpressionTreeNode node) {
     
    610722        var constT = node as ConstantTreeNode;
    611723        return MakeConstant(constT.Value * constT.Value);
     724      } else if (IsFactor(node)) {
     725        var factNode = node as FactorVariableTreeNode;
     726        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => w * w));
     727      } else if (IsBinFactor(node)) {
     728        var binFactor = node as BinaryFactorVariableTreeNode;
     729        return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, binFactor.Weight * binFactor.Weight);
    612730      } else if (IsSquareRoot(node)) {
    613731        return node.GetSubtree(0);
     
    618736      }
    619737    }
     738
    620739    private ISymbolicExpressionTreeNode MakeSquareRoot(ISymbolicExpressionTreeNode node) {
    621740      if (IsConstant(node)) {
    622741        var constT = node as ConstantTreeNode;
    623742        return MakeConstant(Math.Sqrt(constT.Value));
     743      } else if (IsFactor(node)) {
     744        var factNode = node as FactorVariableTreeNode;
     745        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Sqrt(w)));
     746      } else if (IsBinFactor(node)) {
     747        var binFactor = node as BinaryFactorVariableTreeNode;
     748        return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Sqrt(binFactor.Weight));
    624749      } else if (IsSquare(node)) {
    625750        return node.GetSubtree(0);
     
    631756    }
    632757
    633     private ISymbolicExpressionTreeNode MakeLog(ISymbolicExpressionTreeNode node) {
    634       if (IsConstant(node)) {
    635         var constT = node as ConstantTreeNode;
    636         return MakeConstant(Math.Log(constT.Value));
    637       } else if (IsExp(node)) {
    638         return node.GetSubtree(0);
    639       } else if (IsSquareRoot(node)) {
    640         return MakeFraction(MakeLog(node.GetSubtree(0)), MakeConstant(2.0));
    641       } else {
    642         var logNode = logSymbol.CreateTreeNode();
    643         logNode.AddSubtree(node);
    644         return logNode;
    645       }
    646     }
    647758    private ISymbolicExpressionTreeNode MakeRoot(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
    648759      if (IsConstant(a) && IsConstant(b)) {
     
    650761        var constB = b as ConstantTreeNode;
    651762        return MakeConstant(Math.Pow(constA.Value, 1.0 / Math.Round(constB.Value)));
     763      } else if (IsFactor(a) && IsConstant(b)) {
     764        var factNode = a as FactorVariableTreeNode;
     765        var constNode = b as ConstantTreeNode;
     766        return MakeFactor(factNode.Symbol, factNode.VariableName,
     767          factNode.Weights.Select(w => Math.Pow(w, 1.0 / Math.Round(constNode.Value))));
     768      } else if (IsBinFactor(a) && IsConstant(b)) {
     769        var binFactor = a as BinaryFactorVariableTreeNode;
     770        var constNode = b as ConstantTreeNode;
     771        return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Pow(binFactor.Weight, 1.0 / Math.Round(constNode.Value)));
     772      } else if (IsConstant(a) && IsFactor(b)) {
     773        var constNode = a as ConstantTreeNode;
     774        var factNode = b as FactorVariableTreeNode;
     775        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Pow(constNode.Value, 1.0 / Math.Round(w))));
     776      } else if (IsConstant(a) && IsBinFactor(b)) {
     777        var constNode = a as ConstantTreeNode;
     778        var factNode = b as BinaryFactorVariableTreeNode;
     779        return MakeBinFactor(factNode.Symbol, factNode.VariableName, factNode.VariableValue, Math.Pow(constNode.Value, 1.0 / Math.Round(factNode.Weight)));
     780      } else if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
     781        var node0 = a as FactorVariableTreeNode;
     782        var node1 = b as FactorVariableTreeNode;
     783        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => Math.Pow(u, 1.0 / Math.Round(v))));
    652784      } else if (IsConstant(b)) {
    653785        var constB = b as ConstantTreeNode;
     
    677809      }
    678810    }
     811
     812
    679813    private ISymbolicExpressionTreeNode MakePower(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
    680814      if (IsConstant(a) && IsConstant(b)) {
     
    682816        var constB = b as ConstantTreeNode;
    683817        return MakeConstant(Math.Pow(constA.Value, Math.Round(constB.Value)));
     818      } else if (IsFactor(a) && IsConstant(b)) {
     819        var factNode = a as FactorVariableTreeNode;
     820        var constNode = b as ConstantTreeNode;
     821        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Pow(w, Math.Round(constNode.Value))));
     822      } else if (IsBinFactor(a) && IsConstant(b)) {
     823        var binFactor = a as BinaryFactorVariableTreeNode;
     824        var constNode = b as ConstantTreeNode;
     825        return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Pow(binFactor.Weight, Math.Round(constNode.Value)));
     826      } else if (IsConstant(a) && IsFactor(b)) {
     827        var constNode = a as ConstantTreeNode;
     828        var factNode = b as FactorVariableTreeNode;
     829        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Pow(constNode.Value, Math.Round(w))));
     830      } else if (IsConstant(a) && IsBinFactor(b)) {
     831        var constNode = a as ConstantTreeNode;
     832        var factNode = b as BinaryFactorVariableTreeNode;
     833        return MakeBinFactor(factNode.Symbol, factNode.VariableName, factNode.VariableValue, Math.Pow(constNode.Value, Math.Round(factNode.Weight)));
     834      } else if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
     835        var node0 = a as FactorVariableTreeNode;
     836        var node1 = b as FactorVariableTreeNode;
     837        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => Math.Pow(u, Math.Round(v))));
    684838      } else if (IsConstant(b)) {
    685839        var constB = b as ConstantTreeNode;
     
    716870        // fold constants
    717871        return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
    718       } if (IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
     872      } else if ((IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0))) {
    719873        return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
    720       } else if (IsVariable(a) && IsConstant(b)) {
     874      } else if (IsVariableBase(a) && IsConstant(b)) {
    721875        // merge constant values into variable weights
    722876        var constB = ((ConstantTreeNode)b).Value;
    723         ((VariableTreeNode)a).Weight /= constB;
     877        ((VariableTreeNodeBase)a).Weight /= constB;
    724878        return a;
    725       } else if (IsVariable(a) && IsVariable(b) && AreSameVariable(a, b)) {
    726         // cancel variables
     879      } else if (IsFactor(a) && IsConstant(b)) {
     880        var factNode = a as FactorVariableTreeNode;
     881        var constNode = b as ConstantTreeNode;
     882        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => w / constNode.Value));
     883      } else if (IsBinFactor(a) && IsConstant(b)) {
     884        var factNode = a as BinaryFactorVariableTreeNode;
     885        var constNode = b as ConstantTreeNode;
     886        return MakeBinFactor(factNode.Symbol, factNode.VariableName, factNode.VariableValue, factNode.Weight / constNode.Value);
     887      } else if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
     888        var node0 = a as FactorVariableTreeNode;
     889        var node1 = b as FactorVariableTreeNode;
     890        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => u / v));
     891      } else if (IsFactor(a) && IsBinFactor(b) && ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
     892        var node0 = a as FactorVariableTreeNode;
     893        var node1 = b as BinaryFactorVariableTreeNode;
     894        var varValues = node0.Symbol.GetVariableValues(node0.VariableName).ToArray();
     895        var wi = Array.IndexOf(varValues, node1.VariableValue);
     896        if (wi < 0) throw new ArgumentException();
     897        var newWeighs = new double[varValues.Length];
     898        node0.Weights.CopyTo(newWeighs, 0);
     899        for (int i = 0; i < newWeighs.Length; i++)
     900          if (wi == i) newWeighs[i] /= node1.Weight;
     901          else newWeighs[i] /= 0.0;
     902        return MakeFactor(node0.Symbol, node0.VariableName, newWeighs);
     903      } else if (IsFactor(a)) {
     904        return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
     905      } else if (IsVariableBase(a) && IsVariableBase(b) && AreSameTypeAndVariable(a, b) && !IsBinFactor(b)) {
     906        // cancel variables (not allowed for bin factors because of division by zero)
    727907        var aVar = a as VariableTreeNode;
    728908        var bVar = b as VariableTreeNode;
     
    731911        return a.Subtrees
    732912          .Select(x => GetSimplifiedTree(x))
    733          .Select(x => MakeFraction(x, b))
    734          .Aggregate((c, d) => MakeSum(c, d));
     913          .Select(x => MakeFraction(x, GetSimplifiedTree(b)))
     914          .Aggregate((c, d) => MakeSum(c, d));
    735915      } else if (IsMultiplication(a) && IsConstant(b)) {
    736916        return MakeProduct(a, Invert(b));
     
    767947        // x + 0 => x
    768948        return a;
     949      } else if (IsFactor(a) && IsConstant(b)) {
     950        var factNode = a as FactorVariableTreeNode;
     951        var constNode = b as ConstantTreeNode;
     952        return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select((w) => w + constNode.Value));
     953      } else if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
     954        var node0 = a as FactorVariableTreeNode;
     955        var node1 = b as FactorVariableTreeNode;
     956        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => u + v));
     957      } else if (IsBinFactor(a) && IsFactor(b)) {
     958        return MakeSum(b, a);
     959      } else if (IsFactor(a) && IsBinFactor(b) &&
     960        ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
     961        var node0 = a as FactorVariableTreeNode;
     962        var node1 = b as BinaryFactorVariableTreeNode;
     963        var varValues = node0.Symbol.GetVariableValues(node0.VariableName).ToArray();
     964        var wi = Array.IndexOf(varValues, node1.VariableValue);
     965        if (wi < 0) throw new ArgumentException();
     966        var newWeighs = new double[varValues.Length];
     967        node0.Weights.CopyTo(newWeighs, 0);
     968        newWeighs[wi] += node1.Weight;
     969        return MakeFactor(node0.Symbol, node0.VariableName, newWeighs);
    769970      } else if (IsAddition(a) && IsAddition(b)) {
    770971        // merge additions
     
    8291030
    8301031    // makes sure variable symbols in sums are combined
    831     // possible improvement: combine sums of products where the products only reference the same variable
    8321032    private void MergeVariablesInSum(ISymbolicExpressionTreeNode sum) {
    8331033      var subtrees = new List<ISymbolicExpressionTreeNode>(sum.Subtrees);
    8341034      while (sum.Subtrees.Any()) sum.RemoveSubtree(0);
    835       var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
    836                             let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
    837                             group node by node.VariableName + lag into g
     1035      var groupedVarNodes = from node in subtrees.OfType<IVariableTreeNode>()
     1036                            where node.SubtreeCount == 0
     1037                            group node by GroupId(node) into g
    8381038                            select g;
    839       var unchangedSubtrees = subtrees.Where(t => !(t is VariableTreeNode));
     1039      var constant = (from node in subtrees.OfType<ConstantTreeNode>()
     1040                      select node.Value).DefaultIfEmpty(0.0).Sum();
     1041      var unchangedSubtrees = subtrees.Where(t => t.SubtreeCount > 0 || !(t is IVariableTreeNode) && !(t is ConstantTreeNode));
    8401042
    8411043      foreach (var variableNodeGroup in groupedVarNodes) {
    842         var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
    843         var representative = variableNodeGroup.First();
    844         representative.Weight = weightSum;
    845         sum.AddSubtree(representative);
     1044        var firstNode = variableNodeGroup.First();
     1045        if (firstNode is VariableTreeNodeBase) {
     1046          var representative = firstNode as VariableTreeNodeBase;
     1047          var weightSum = variableNodeGroup.Cast<VariableTreeNodeBase>().Select(t => t.Weight).Sum();
     1048          representative.Weight = weightSum;
     1049          sum.AddSubtree(representative);
     1050        } else if (firstNode is FactorVariableTreeNode) {
     1051          var representative = firstNode as FactorVariableTreeNode;
     1052          foreach (var node in variableNodeGroup.Skip(1).Cast<FactorVariableTreeNode>()) {
     1053            for (int j = 0; j < representative.Weights.Length; j++) {
     1054              representative.Weights[j] += node.Weights[j];
     1055            }
     1056          }
     1057          for (int j = 0; j < representative.Weights.Length; j++) {
     1058            representative.Weights[j] += constant;
     1059          }
     1060          sum.AddSubtree(representative);
     1061        }
    8461062      }
    8471063      foreach (var unchangedSubtree in unchangedSubtrees)
    8481064        sum.AddSubtree(unchangedSubtree);
     1065      if (!constant.IsAlmost(0.0)) {
     1066        sum.AddSubtree(MakeConstant(constant));
     1067      }
     1068    }
     1069
     1070    // nodes referencing variables can be grouped if they have
     1071    private string GroupId(IVariableTreeNode node) {
     1072      var binaryFactorNode = node as BinaryFactorVariableTreeNode;
     1073      var factorNode = node as FactorVariableTreeNode;
     1074      var variableNode = node as VariableTreeNode;
     1075      var laggedVarNode = node as LaggedVariableTreeNode;
     1076      if (variableNode != null) {
     1077        return "var " + variableNode.VariableName;
     1078      } else if (binaryFactorNode != null) {
     1079        return "binfactor " + binaryFactorNode.VariableName + " " + binaryFactorNode.VariableValue;
     1080      } else if (factorNode != null) {
     1081        return "factor " + factorNode.VariableName;
     1082      } else if (laggedVarNode != null) {
     1083        return "lagged " + laggedVarNode.VariableName + " " + laggedVarNode.Lag;
     1084      } else {
     1085        throw new NotSupportedException();
     1086      }
    8491087    }
    8501088
     
    8531091      if (IsConstant(a) && IsConstant(b)) {
    8541092        // fold constants
    855         ((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
    856         return a;
     1093        return MakeConstant(((ConstantTreeNode)a).Value * ((ConstantTreeNode)b).Value);
    8571094      } else if (IsConstant(a)) {
    8581095        // a * $ => $ * a
    8591096        return MakeProduct(b, a);
     1097      } else if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
     1098        var node0 = a as FactorVariableTreeNode;
     1099        var node1 = b as FactorVariableTreeNode;
     1100        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => u * v));
     1101      } else if (IsBinFactor(a) && IsBinFactor(b) && AreSameTypeAndVariable(a, b)) {
     1102        var node0 = a as BinaryFactorVariableTreeNode;
     1103        var node1 = b as BinaryFactorVariableTreeNode;
     1104        return MakeBinFactor(node0.Symbol, node0.VariableName, node0.VariableValue, node0.Weight * node1.Weight);
     1105      } else if (IsFactor(a) && IsConstant(b)) {
     1106        var node0 = a as FactorVariableTreeNode;
     1107        var node1 = b as ConstantTreeNode;
     1108        return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Select(w => w * node1.Value));
     1109      } else if (IsBinFactor(a) && IsConstant(b)) {
     1110        var node0 = a as BinaryFactorVariableTreeNode;
     1111        var node1 = b as ConstantTreeNode;
     1112        return MakeBinFactor(node0.Symbol, node0.VariableName, node0.VariableValue, node0.Weight * node1.Value);
     1113      } else if (IsBinFactor(a) && IsFactor(b)) {
     1114        return MakeProduct(b, a);
     1115      } else if (IsFactor(a) && IsBinFactor(b) &&
     1116        ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
     1117        var node0 = a as FactorVariableTreeNode;
     1118        var node1 = b as BinaryFactorVariableTreeNode;
     1119        var varValues = node0.Symbol.GetVariableValues(node0.VariableName).ToArray();
     1120        var wi = Array.IndexOf(varValues, node1.VariableValue);
     1121        if (wi < 0) throw new ArgumentException();
     1122        return MakeBinFactor(node1.Symbol, node1.VariableName, node1.VariableValue, node1.Weight * node0.Weights[wi]);
    8601123      } else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
    8611124        // $ * 1.0 => $
    8621125        return a;
    863       } else if (IsConstant(b) && IsVariable(a)) {
     1126      } else if (IsConstant(b) && IsVariableBase(a)) {
    8641127        // multiply constants into variables weights
    865         ((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
     1128        ((VariableTreeNodeBase)a).Weight *= ((ConstantTreeNode)b).Value;
    8661129        return a;
    867       } else if (IsConstant(b) && IsAddition(a)) {
     1130      } else if (IsConstant(b) && IsAddition(a) ||
     1131          IsFactor(b) && IsAddition(a) ||
     1132          IsBinFactor(b) && IsAddition(a)) {
    8681133        // multiply constants into additions
    869         return a.Subtrees.Select(x => MakeProduct(x, b)).Aggregate((c, d) => MakeSum(c, d));
     1134        return a.Subtrees.Select(x => MakeProduct(GetSimplifiedTree(x), GetSimplifiedTree(b))).Aggregate((c, d) => MakeSum(c, d));
    8701135      } else if (IsDivision(a) && IsDivision(b)) {
    8711136        // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
     
    8881153      } else if (IsMultiplication(a)) {
    8891154        // a is already an multiplication => append b
    890         a.AddSubtree(b);
     1155        a.AddSubtree(GetSimplifiedTree(b));
    8911156        MergeVariablesAndConstantsInProduct(a);
    8921157        return a;
     
    8991164      }
    9001165    }
     1166
    9011167    #endregion
    9021168
    903 
    9041169    #region helper functions
     1170
    9051171    private bool ContainsVariableCondition(ISymbolicExpressionTreeNode node) {
    9061172      if (node.Symbol is VariableCondition) return true;
     
    9321198    }
    9331199
    934     private bool AreSameVariable(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
    935       var aLaggedVar = a as LaggedVariableTreeNode;
    936       var bLaggedVar = b as LaggedVariableTreeNode;
    937       if (aLaggedVar != null && bLaggedVar != null) {
    938         return aLaggedVar.VariableName == bLaggedVar.VariableName &&
    939           aLaggedVar.Lag == bLaggedVar.Lag;
    940       }
    941       var aVar = a as VariableTreeNode;
    942       var bVar = b as VariableTreeNode;
    943       if (aVar != null && bVar != null) {
    944         return aVar.VariableName == bVar.VariableName;
    945       }
    946       return false;
     1200    private bool AreSameTypeAndVariable(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
     1201      return GroupId((IVariableTreeNode)a) == GroupId((IVariableTreeNode)b);
    9471202    }
    9481203
     
    9511206      var subtrees = new List<ISymbolicExpressionTreeNode>(prod.Subtrees);
    9521207      while (prod.Subtrees.Any()) prod.RemoveSubtree(0);
    953       var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
    954                             let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
    955                             group node by node.VariableName + lag into g
     1208      var groupedVarNodes = from node in subtrees.OfType<IVariableTreeNode>()
     1209                            where node.SubtreeCount == 0
     1210                            group node by GroupId(node) into g
    9561211                            orderby g.Count()
    9571212                            select g;
    958       var constantProduct = (from node in subtrees.OfType<VariableTreeNode>()
     1213      var constantProduct = (from node in subtrees.OfType<VariableTreeNodeBase>()
    9591214                             select node.Weight)
    960                             .Concat(from node in subtrees.OfType<ConstantTreeNode>()
    961                                     select node.Value)
    962                             .DefaultIfEmpty(1.0)
    963                             .Aggregate((c1, c2) => c1 * c2);
     1215        .Concat(from node in subtrees.OfType<ConstantTreeNode>()
     1216                select node.Value)
     1217        .DefaultIfEmpty(1.0)
     1218        .Aggregate((c1, c2) => c1 * c2);
    9641219
    9651220      var unchangedSubtrees = from tree in subtrees
    966                               where !(tree is VariableTreeNode)
    967                               where !(tree is ConstantTreeNode)
     1221                              where tree.SubtreeCount > 0 || !(tree is IVariableTreeNode) && !(tree is ConstantTreeNode)
    9681222                              select tree;
    9691223
    9701224      foreach (var variableNodeGroup in groupedVarNodes) {
    971         var representative = variableNodeGroup.First();
    972         representative.Weight = 1.0;
    973         if (variableNodeGroup.Count() > 1) {
    974           var poly = mulSymbol.CreateTreeNode();
    975           for (int p = 0; p < variableNodeGroup.Count(); p++) {
    976             poly.AddSubtree((ISymbolicExpressionTreeNode)representative.Clone());
     1225        var firstNode = variableNodeGroup.First();
     1226        if (firstNode is VariableTreeNodeBase) {
     1227          var representative = (VariableTreeNodeBase)firstNode;
     1228          representative.Weight = 1.0;
     1229          if (variableNodeGroup.Count() > 1) {
     1230            var poly = mulSymbol.CreateTreeNode();
     1231            for (int p = 0; p < variableNodeGroup.Count(); p++) {
     1232              poly.AddSubtree((ISymbolicExpressionTreeNode)representative.Clone());
     1233            }
     1234            prod.AddSubtree(poly);
     1235          } else {
     1236            prod.AddSubtree(representative);
    9771237          }
    978           prod.AddSubtree(poly);
    979         } else {
     1238        } else if (firstNode is FactorVariableTreeNode) {
     1239          var representative = (FactorVariableTreeNode)firstNode;
     1240          foreach (var node in variableNodeGroup.Skip(1).Cast<FactorVariableTreeNode>()) {
     1241            for (int j = 0; j < representative.Weights.Length; j++) {
     1242              representative.Weights[j] *= node.Weights[j];
     1243            }
     1244          }
     1245          for (int j = 0; j < representative.Weights.Length; j++) {
     1246            representative.Weights[j] *= constantProduct;
     1247          }
     1248          constantProduct = 1.0;
     1249          // if the product already contains a factor it is not necessary to multiply a constant below
    9801250          prod.AddSubtree(representative);
    9811251        }
     
    9931263    /// <summary>
    9941264    /// x => x * -1
    995     /// Doesn't create new trees and manipulates x
     1265    /// Is only used in cases where it is not necessary to create new tree nodes. Manipulates x directly.
    9961266    /// </summary>
    9971267    /// <param name="x"></param>
     
    10001270      if (IsConstant(x)) {
    10011271        ((ConstantTreeNode)x).Value *= -1;
    1002       } else if (IsVariable(x)) {
    1003         var variableTree = (VariableTreeNode)x;
     1272      } else if (IsVariableBase(x)) {
     1273        var variableTree = (VariableTreeNodeBase)x;
    10041274        variableTree.Weight *= -1.0;
     1275      } else if (IsFactor(x)) {
     1276        var factorNode = (FactorVariableTreeNode)x;
     1277        for (int i = 0; i < factorNode.Weights.Length; i++) factorNode.Weights[i] *= -1;
     1278      } else if (IsBinFactor(x)) {
     1279        var factorNode = (BinaryFactorVariableTreeNode)x;
     1280        factorNode.Weight *= -1;
    10051281      } else if (IsAddition(x)) {
    10061282        // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)       
     
    10241300    /// <summary>
    10251301    /// x => 1/x
    1026     /// Doesn't create new trees and manipulates x
     1302    /// Must create new tree nodes
    10271303    /// </summary>
    10281304    /// <param name="x"></param>
     
    10311307      if (IsConstant(x)) {
    10321308        return MakeConstant(1.0 / ((ConstantTreeNode)x).Value);
     1309      } else if (IsFactor(x)) {
     1310        var factorNode = (FactorVariableTreeNode)x;
     1311        return MakeFactor(factorNode.Symbol, factorNode.VariableName, factorNode.Weights.Select(w => 1.0 / w));
    10331312      } else if (IsDivision(x)) {
    10341313        return MakeFraction(x.GetSubtree(1), x.GetSubtree(0));
     
    10451324    }
    10461325
    1047     private ISymbolicExpressionTreeNode MakeVariable(double weight, string name) {
    1048       var tree = (VariableTreeNode)varSymbol.CreateTreeNode();
     1326    private ISymbolicExpressionTreeNode MakeFactor(FactorVariable sy, string variableName, IEnumerable<double> weights) {
     1327      var tree = (FactorVariableTreeNode)sy.CreateTreeNode();
     1328      tree.VariableName = variableName;
     1329      tree.Weights = weights.ToArray();
     1330      return tree;
     1331    }
     1332    private ISymbolicExpressionTreeNode MakeBinFactor(BinaryFactorVariable sy, string variableName, string variableValue, double weight) {
     1333      var tree = (BinaryFactorVariableTreeNode)sy.CreateTreeNode();
     1334      tree.VariableName = variableName;
     1335      tree.VariableValue = variableValue;
    10491336      tree.Weight = weight;
    1050       tree.VariableName = name;
    10511337      return tree;
    10521338    }
     1339
     1340
    10531341    #endregion
    10541342  }
Note: See TracChangeset for help on using the changeset viewer.