Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/07/08 00:02:43 (16 years ago)
Author:
gkronber
Message:

worked on #139:

  • fixed display of trees in the gui
  • split list representation of tree into two lists code and data
  • implemented static evaluation for all predefined functions (except ProgrammableFunction)
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/ExperimentalFunctionsBaking/BakedFunctionTree.cs

    r208 r220  
    3030namespace HeuristicLab.Functions {
    3131  class BakedFunctionTree : ItemBase, IFunctionTree {
    32     private List<double> code;
    33     private static double nextFunctionSymbol = 10000;
    34     private static Dictionary<double, IFunction> symbolTable = new Dictionary<double, IFunction>();
    35     private static Dictionary<IFunction, double> reverseSymbolTable = new Dictionary<IFunction, double>();
    36     private static double additionSymbol = -1;
    37     private static double substractionSymbol = -1;
    38     private static double multiplicationSymbol = -1;
    39     private static double divisionSymbol = -1;
    40     private static double variableSymbol = -1;
    41     private static double constantSymbol = -1;
     32    private List<int> code;
     33    private List<double> data;
     34    private const int ADDITION = 10010;
     35    private const int AND = 10020;
     36    private const int AVERAGE = 10030;
     37    private const int CONSTANT = 10040;
     38    private const int COSINUS = 10050;
     39    private const int DIVISION = 10060;
     40    private const int EQU = 10070;
     41    private const int EXP = 10080;
     42    private const int GT = 10090;
     43    private const int IFTE = 10100;
     44    private const int LT = 10110;
     45    private const int LOG = 10120;
     46    private const int MULTIPLICATION = 10130;
     47    private const int NOT = 10140;
     48    private const int OR = 10150;
     49    private const int POWER = 10160;
     50    private const int SIGNUM = 10170;
     51    private const int SINUS = 10180;
     52    private const int SQRT = 10190;
     53    private const int SUBSTRACTION = 10200;
     54    private const int TANGENS = 10210;
     55    private const int VARIABLE = 10220;
     56    private const int XOR = 10230;
     57
     58    private static int nextFunctionSymbol = 10240;
     59    private static Dictionary<int, IFunction> symbolTable;
     60    private static Dictionary<IFunction, int> reverseSymbolTable;
     61    private static Dictionary<Type, int> staticTypes;
     62
     63    static BakedFunctionTree() {
     64      symbolTable = new Dictionary<int, IFunction>();
     65      reverseSymbolTable = new Dictionary<IFunction, int>();
     66      staticTypes = new Dictionary<Type, int>();
     67      staticTypes[typeof(Addition)] = ADDITION;
     68      staticTypes[typeof(And)] = AND;
     69      staticTypes[typeof(Average)] = AVERAGE;
     70      staticTypes[typeof(Constant)] = CONSTANT;
     71      staticTypes[typeof(Cosinus)] = COSINUS;
     72      staticTypes[typeof(Division)] = DIVISION;
     73      staticTypes[typeof(Equal)] = EQU;
     74      staticTypes[typeof(Exponential)] = EXP;
     75      staticTypes[typeof(GreaterThan)] = GT;
     76      staticTypes[typeof(IfThenElse)] = IFTE;
     77      staticTypes[typeof(LessThan)] = LT;
     78      staticTypes[typeof(Logarithm)] = LOG;
     79      staticTypes[typeof(Multiplication)] = MULTIPLICATION;
     80      staticTypes[typeof(Not)] = NOT;
     81      staticTypes[typeof(Or)] = OR;
     82      staticTypes[typeof(Power)] = POWER;
     83      staticTypes[typeof(Signum)] = SIGNUM;
     84      staticTypes[typeof(Sinus)] = SINUS;
     85      staticTypes[typeof(Sqrt)] = SQRT;
     86      staticTypes[typeof(Substraction)] = SUBSTRACTION;
     87      staticTypes[typeof(Tangens)] = TANGENS;
     88      staticTypes[typeof(Variable)] = VARIABLE;
     89      staticTypes[typeof(Xor)] = XOR;
     90    }
    4291
    4392    internal BakedFunctionTree() {
    44       code = new List<double>();
     93      code = new List<int>();
     94      data = new List<double>();
    4595    }
    4696
     
    65115      code.Add(0);
    66116      code.Add(MapFunction(tree.Function));
    67       code.Add((byte)tree.LocalVariables.Count);
     117      code.Add(tree.LocalVariables.Count);
    68118      foreach(IVariable variable in tree.LocalVariables) {
    69119        IItem value = variable.Value;
    70         code.Add(GetDoubleValue(value));
     120        data.Add(GetDoubleValue(value));
    71121      }
    72122      foreach(IFunctionTree subTree in tree.SubTrees) {
     
    87137    }
    88138
    89     private double MapFunction(IFunction function) {
     139    private int MapFunction(IFunction function) {
    90140      if(!reverseSymbolTable.ContainsKey(function)) {
    91         reverseSymbolTable[function] = nextFunctionSymbol;
    92         symbolTable[nextFunctionSymbol] = function;
    93         if(function is Variable) {
    94           variableSymbol = nextFunctionSymbol;
    95         } else if(function is Constant) {
    96           constantSymbol = nextFunctionSymbol;
    97         } else if(function is Addition) {
    98           additionSymbol = nextFunctionSymbol;
    99         } else if(function is Substraction) {
    100           substractionSymbol = nextFunctionSymbol;
    101         } else if(function is Multiplication) {
    102           multiplicationSymbol = nextFunctionSymbol;
    103         } else if(function is Division) {
    104           divisionSymbol = nextFunctionSymbol;
    105         } else throw new NotSupportedException("Unsupported function " + function);
    106 
    107         nextFunctionSymbol++;
     141        int curFunctionSymbol;
     142        if(staticTypes.ContainsKey(function.GetType())) curFunctionSymbol = staticTypes[function.GetType()];
     143        else {
     144          curFunctionSymbol = nextFunctionSymbol;
     145          nextFunctionSymbol++;
     146        }
     147        reverseSymbolTable[function] = curFunctionSymbol;
     148        symbolTable[curFunctionSymbol] = function;
    108149      }
    109150      return reverseSymbolTable[function];
    110151    }
    111152
    112     private int BranchLength(int branchRoot) {
    113       double arity = code[branchRoot];
    114       int nLocalVariables = (int)code[branchRoot + 2];
    115       int len = 3 + nLocalVariables;
    116       int subBranchStart = branchRoot + len;
     153    private void BranchLength(int branchRoot, out int codeLength, out int dataLength) {
     154      int arity = code[branchRoot];
     155      int nLocalVariables = code[branchRoot + 2];
     156      codeLength = 3;
     157      dataLength = nLocalVariables;
     158      int subBranchStart = branchRoot + codeLength;
    117159      for(int i = 0; i < arity; i++) {
    118         int branchLen = BranchLength(subBranchStart);
    119         len += branchLen;
    120         subBranchStart += branchLen;
    121       }
    122       return len;
     160        int branchCodeLength;
     161        int branchDataLength;
     162        BranchLength(subBranchStart, out branchCodeLength, out branchDataLength);
     163        subBranchStart += branchCodeLength;
     164        codeLength += branchCodeLength;
     165        dataLength += branchDataLength;
     166      }
    123167    }
    124168
     
    130174          subTree.FlattenTrees();
    131175          code.AddRange(subTree.code);
     176          data.AddRange(subTree.data);
    132177        }
    133178        treesExpanded = false;
     
    139184      if(variablesExpanded) {
    140185        code[2] = variables.Count;
    141         int localVariableIndex = 3;
    142186        foreach(IVariable variable in variables) {
    143           code.Insert(localVariableIndex, GetDoubleValue(variable.Value));
    144           localVariableIndex++;
     187          data.Add(GetDoubleValue(variable.Value));
    145188        }
    146189        variablesExpanded = false;
     
    155198        if(!treesExpanded) {
    156199          subTrees = new List<IFunctionTree>();
    157           double arity = code[0];
    158           int nLocalVariables = (int)code[2];
    159           int branchIndex = 3 + nLocalVariables;
     200          int arity = code[0];
     201          int nLocalVariables = code[2];
     202          int branchIndex = 3;
     203          int dataIndex = nLocalVariables; // skip my local variables to reach the local variables of the first branch
    160204          for(int i = 0; i < arity; i++) {
    161205            BakedFunctionTree subTree = new BakedFunctionTree();
    162             int branchLen = BranchLength(branchIndex);
    163             subTree.code = code.GetRange(branchIndex, branchLen);
    164             branchIndex += branchLen;
     206            int codeLength;
     207            int dataLength;
     208            BranchLength(branchIndex, out codeLength, out dataLength);
     209            subTree.code = code.GetRange(branchIndex, codeLength);
     210            subTree.data = data.GetRange(dataIndex, dataLength);
     211            branchIndex += codeLength;
     212            dataIndex += dataLength;
    165213            subTrees.Add(subTree);
    166214          }
    167215          treesExpanded = true;
    168           code.RemoveRange(3 + nLocalVariables, code.Count - (3 + nLocalVariables));
     216          code.RemoveRange(3, code.Count - 3);
    169217          code[0] = 0;
     218          data.RemoveRange(nLocalVariables, data.Count - nLocalVariables);
    170219        }
    171220        return subTrees;
     
    180229          variables = new List<IVariable>();
    181230          IFunction function = symbolTable[code[1]];
    182           int localVariableIndex = 3;
     231          int localVariableIndex = 0;
    183232          foreach(IVariableInfo variableInfo in function.VariableInfos) {
    184233            if(variableInfo.Local) {
     
    186235              IItem value = clone.Value;
    187236              if(value is ConstrainedDoubleData) {
    188                 ((ConstrainedDoubleData)value).Data = code[localVariableIndex];
     237                ((ConstrainedDoubleData)value).Data = data[localVariableIndex];
    189238              } else if(value is ConstrainedIntData) {
    190                 ((ConstrainedIntData)value).Data = (int)code[localVariableIndex];
     239                ((ConstrainedIntData)value).Data = (int)data[localVariableIndex];
    191240              } else if(value is DoubleData) {
    192                 ((DoubleData)value).Data = code[localVariableIndex];
     241                ((DoubleData)value).Data = data[localVariableIndex];
    193242              } else if(value is IntData) {
    194                 ((IntData)value).Data = (int)code[localVariableIndex];
     243                ((IntData)value).Data = (int)data[localVariableIndex];
    195244              } else throw new NotSupportedException("Invalid local variable type for GP.");
    196245              variables.Add(clone);
     
    200249          variablesExpanded = true;
    201250          code[2] = 0;
    202           code.RemoveRange(3, variables.Count);
     251          data.RemoveRange(0, variables.Count);
    203252        }
    204253        return variables;
     
    211260
    212261    public IVariable GetLocalVariable(string name) {
    213       throw new NotImplementedException();
     262      foreach(IVariable var in LocalVariables) {
     263        if(var.Name == name) return var;
     264      }
     265      return null;
    214266    }
    215267
    216268    public void AddVariable(IVariable variable) {
    217       throw new NotImplementedException();
     269      throw new NotSupportedException();
    218270    }
    219271
    220272    public void RemoveVariable(string name) {
    221       throw new NotImplementedException();
     273      throw new NotSupportedException();
    222274    }
    223275
    224276    public void AddSubTree(IFunctionTree tree) {
    225       if(treesExpanded) {
    226         subTrees.Add(tree);
    227       } else {
    228         //code.AddRange(((BakedFunctionTree)tree).code);
    229         //code[0] = code[0] + 1;
    230         throw new NotImplementedException();
    231       }
     277      if(!treesExpanded) throw new InvalidOperationException();
     278      subTrees.Add(tree);
    232279    }
    233280
    234281    public void InsertSubTree(int index, IFunctionTree tree) {
    235       if(treesExpanded) {
    236         subTrees.Insert(index, tree);
    237       } else {
    238         //byte nLocalVariables = code[2];
    239         //// skip branches
    240         //int currentBranchIndex = 3 + nLocalVariables;
    241         //for(int i = 0; i < index; i++) {
    242         //  int branchLength = BranchLength(currentBranchIndex);
    243         //  currentBranchIndex += branchLength;
    244         //}
    245 
    246         //code.InsertRange(currentBranchIndex, ((BakedFunctionTree)tree).code);
    247         //code[0] = code[0] + 1;
    248 
    249         throw new NotImplementedException();
    250       }
     282      if(!treesExpanded) throw new InvalidOperationException();
     283      subTrees.Insert(index, tree);
    251284    }
    252285
    253286    public void RemoveSubTree(int index) {
    254       if(treesExpanded) {
    255         subTrees.RemoveAt(index);
    256       } else {
    257         //int nLocalVariables = (int)code[2];
    258         //// skip branches
    259         //int currentBranchIndex = 3 + nLocalVariables;
    260         //for(int i = 0; i < index; i++) {
    261         //  int branchLength = BranchLength(currentBranchIndex);
    262         //  currentBranchIndex += branchLength;
    263         //}
    264         //int deletedBranchLength = BranchLength(currentBranchIndex);
    265         //code.RemoveRange(currentBranchIndex, deletedBranchLength);
    266         //code[0] = code[0] - 1;
    267         throw new NotImplementedException();
    268       }
     287      // sanity check
     288      if(!treesExpanded) throw new InvalidOperationException();
     289      subTrees.RemoveAt(index);
    269290    }
    270291
    271292    private int PC;
    272     private double[] codeArr;
     293    private int DP;
     294    private int[] codeArr;
     295    private double[] dataArr;
     296    private Dataset dataset;
     297    private int sampleIndex;
    273298    public double Evaluate(Dataset dataset, int sampleIndex) {
    274299      PC = 0;
     300      DP = 0;
    275301      FlattenVariables();
    276302      FlattenTrees();
    277303      if(codeArr == null) {
    278         codeArr = new double[code.Count];
     304        codeArr = new int[code.Count];
     305        dataArr = new double[data.Count];
    279306        code.CopyTo(codeArr);
    280       }
    281       return EvaluateBakedCode(dataset, sampleIndex);
    282     }
    283 
    284     private double EvaluateBakedCode(Dataset dataset, int sampleIndex) {
    285       double arity = codeArr[PC++];
    286       double functionSymbol = codeArr[PC++];
    287       double nLocalVariables = codeArr[PC++];
    288       if(functionSymbol == variableSymbol) {
    289         int var = (int)codeArr[PC++];
    290         double weight = codeArr[PC++];
    291         int offset = (int)codeArr[PC++];
    292         return weight * dataset.GetValue(sampleIndex+offset, var);
    293       } else if(functionSymbol == constantSymbol) {
    294         double value = codeArr[PC++];
    295         return value;
    296       } else if(functionSymbol == additionSymbol) {
    297         double sum = 0.0;
    298         for(int i = 0; i < arity; i++) {
    299           sum += EvaluateBakedCode(dataset, sampleIndex);
    300         }
    301         return sum;
    302       } else if(functionSymbol == substractionSymbol) {
    303         if(arity == 1) {
    304           return -EvaluateBakedCode(dataset, sampleIndex);
    305         } else {
    306           double result = EvaluateBakedCode(dataset, sampleIndex);
    307           for(int i = 1; i < arity; i++) {
    308             result -= EvaluateBakedCode(dataset, sampleIndex);
    309           }
    310           return result;
    311         }
    312       } else if(functionSymbol == multiplicationSymbol) {
    313         double result = 1.0;
    314         for(int i = 0; i < arity; i++) {
    315           result *= EvaluateBakedCode(dataset, sampleIndex);
    316         }
    317         return result;
    318       } else if(functionSymbol == divisionSymbol) {
    319         if(arity == 1) {
    320           double divisor = EvaluateBakedCode(dataset, sampleIndex);
    321           if(divisor == 0) return 0;
    322           else return 1.0 / divisor;
    323         } else {
    324           double result = EvaluateBakedCode(dataset, sampleIndex);
    325           for(int i = 1; i < arity; i++) {
    326             double divisor = EvaluateBakedCode(dataset, sampleIndex);
    327             if(divisor == 0) result = 0;
    328             else result /= divisor;
    329           }
    330           return result;
    331         }
    332       } else { throw new NotSupportedException(); }
     307        data.CopyTo(dataArr);
     308      }
     309      this.sampleIndex = sampleIndex;
     310      this.dataset = dataset;
     311      return EvaluateBakedCode();
     312    }
     313
     314    private double EvaluateBakedCode() {
     315      int arity = codeArr[PC++];
     316      int functionSymbol = codeArr[PC++];
     317      int nLocalVariables = codeArr[PC++];
     318      switch(functionSymbol) {
     319        case VARIABLE: {
     320            int var = (int)dataArr[DP++];
     321            double weight = dataArr[DP++];
     322            int offset = (int)dataArr[DP++];
     323            return weight * dataset.GetValue(sampleIndex + offset, var);
     324          }
     325        case CONSTANT: {
     326            double value = dataArr[DP++];
     327            return value;
     328          }
     329        case MULTIPLICATION: {
     330            double result = 1.0;
     331            for(int i = 0; i < arity; i++) {
     332              result *= EvaluateBakedCode();
     333            }
     334            return result;
     335          }
     336        case ADDITION: {
     337            double sum = 0.0;
     338            for(int i = 0; i < arity; i++) {
     339              sum += EvaluateBakedCode();
     340            }
     341            return sum;
     342          }
     343        case SUBSTRACTION: {
     344            if(arity == 1) {
     345              return -EvaluateBakedCode();
     346            } else {
     347              double result = EvaluateBakedCode();
     348              for(int i = 1; i < arity; i++) {
     349                result -= EvaluateBakedCode();
     350              }
     351              return result;
     352            }
     353          }
     354        case DIVISION: {
     355            if(arity == 1) {
     356              double divisor = EvaluateBakedCode();
     357              if(divisor == 0) return 0;
     358              else return 1.0 / divisor;
     359            } else {
     360              double result = EvaluateBakedCode();
     361              for(int i = 1; i < arity; i++) {
     362                double divisor = EvaluateBakedCode();
     363                if(divisor == 0) result = 0;
     364                else result /= divisor;
     365              }
     366              return result;
     367            }
     368          }
     369        case AVERAGE: {
     370            double sum = 0.0;
     371            for(int i = 0; i < arity; i++) {
     372              sum += EvaluateBakedCode();
     373            }
     374            return sum / arity;
     375          }
     376        case COSINUS: {
     377            return Math.Cos(EvaluateBakedCode());
     378          }
     379        case SINUS: {
     380            return Math.Sin(EvaluateBakedCode());
     381          }
     382        case EXP: {
     383            return Math.Exp(EvaluateBakedCode());
     384          }
     385        case LOG: {
     386            return Math.Log(EvaluateBakedCode());
     387          }
     388        case POWER: {
     389            double x = EvaluateBakedCode();
     390            double p = EvaluateBakedCode();
     391            return Math.Pow(x, p);
     392          }
     393        case SIGNUM: {
     394            // protected signum
     395            double value = EvaluateBakedCode();
     396            if(value < 0) return -1;
     397            if(value > 0) return 1;
     398            return 0;
     399          }
     400        case SQRT: {
     401            return Math.Sqrt(EvaluateBakedCode());
     402          }
     403        case TANGENS: {
     404            return Math.Tan(EvaluateBakedCode());
     405          }
     406        case AND: {
     407            double result = 1.0;
     408            // have to evaluate all sub-trees, skipping would probably not lead to a big gain because
     409            // we have to iterate over the linear structure anyway
     410            for(int i = 0; i < arity; i++) {
     411              double x = Math.Round(EvaluateBakedCode());
     412              if(x == 0) result *= 0;
     413              else if(x == 1.0) result *= 1.0;
     414              else result *= double.NaN;
     415            }
     416            return result;
     417          }
     418        case EQU: {
     419            double x = EvaluateBakedCode();
     420            double y = EvaluateBakedCode();
     421            if(x == y) return 1.0; else return 0.0;
     422          }
     423        case GT: {
     424            double x = EvaluateBakedCode();
     425            double y = EvaluateBakedCode();
     426            if(x > y) return 1.0;
     427            else return 0.0;
     428          }
     429        case IFTE: {
     430            double condition = Math.Round(EvaluateBakedCode());
     431            double x = EvaluateBakedCode();
     432            double y = EvaluateBakedCode();
     433            if(condition < .5) return x;
     434            else if(condition >= .5) return y;
     435            else return double.NaN;
     436          }
     437        case LT: {
     438            double x = EvaluateBakedCode();
     439            double y = EvaluateBakedCode();
     440            if(x < y) return 1.0;
     441            else return 0.0;
     442          }
     443        case NOT: {
     444            double result = Math.Round(EvaluateBakedCode());
     445            if(result == 0.0) return 1.0;
     446            else if(result == 1.0) return 0.0;
     447            else return double.NaN;
     448          }
     449        case OR: {
     450            double result = 0.0; // default is false
     451            for(int i = 0; i < arity; i++) {
     452              double x = Math.Round(EvaluateBakedCode());
     453              if(x == 1.0 && result == 0.0) result = 1.0; // found first true (1.0) => set to true
     454              else if(x != 0.0) result = double.NaN; // if it was not true it can only be false (0.0) all other cases are undefined => (NaN)
     455            }
     456            return result;
     457          }
     458        case XOR: {
     459            double x = Math.Round(EvaluateBakedCode());
     460            double y = Math.Round(EvaluateBakedCode());
     461            if(x == 0.0 && y == 0.0) return 0.0;
     462            if(x == 1.0 && y == 0.0) return 1.0;
     463            if(x == 0.0 && y == 1.0) return 1.0;
     464            if(x == 1.0 && y == 1.0) return 0.0;
     465            return double.NaN;
     466          }
     467        default: {
     468            IFunction function = symbolTable[functionSymbol];
     469            double[] args = new double[nLocalVariables + arity];
     470            for(int i = 0; i < nLocalVariables; i++) {
     471              args[i] = dataArr[DP++];
     472            }
     473            for(int j = 0; j < arity; j++) {
     474              args[nLocalVariables + j] = EvaluateBakedCode();
     475            }
     476            return function.Apply(dataset, sampleIndex, args);
     477          }
     478      }
    333479    }
    334480
     
    343489    public override object Clone(IDictionary<Guid, object> clonedObjects) {
    344490      BakedFunctionTree clone = new BakedFunctionTree();
    345       if(treesExpanded || variablesExpanded) throw new InvalidOperationException();
    346       //if(treesExpanded) FlattenTrees();
    347       //if(variablesExpanded) FlattenVariables();
    348       clone.code = new List<double>(code);
     491      if(treesExpanded || variablesExpanded) throw new InvalidOperationException(); // sanity check
     492      clone.code.AddRange(code);
     493      clone.data.AddRange(data);
    349494      return clone;
     495    }
     496
     497    public override IView CreateView() {
     498      return new FunctionTreeView(this);
    350499    }
    351500  }
Note: See TracChangeset for help on using the changeset viewer.