Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ExprHashSymbolic.cs @ 15440

Last change on this file since 15440 was 15440, checked in by gkronber, 7 years ago

#2796 fixed hashing of expressions and unit tests for expression enumeration

File size: 9.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21using System;
22using System.Collections.Generic;
23using System.Diagnostics.Contracts;
24using System.Linq;
25
26namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
/// <summary>
/// Unary functions that can wrap a polynomial argument inside a <c>FunctionFactor</c>.
/// The declaration order (Log, Exp, Inv) fixes the underlying values 0, 1, 2.
/// </summary>
internal enum UnaryFunctionType {
  Log,
  Exp,
  Inv
}

/// <summary>
/// Marker interface without members; implemented by <c>SymbolFactor</c> and <c>FunctionFactor</c>
/// so that both can be stored in the factor list of a <c>Monomial</c>.
/// </summary>
internal interface Factor {
}
/// <summary>
/// Atomic factor that is identified solely by a single character.
/// Two symbol factors are equal exactly when their characters are equal.
/// </summary>
internal class SymbolFactor : Factor {
  internal char symbolId;

  /// <summary>Creates a factor for the given symbol character.</summary>
  public SymbolFactor(char symbolId) {
    this.symbolId = symbolId;
  }

  /// <summary>Returns the hash code of the symbol character.</summary>
  public override int GetHashCode() {
    return symbolId.GetHashCode();
  }

  /// <summary>
  /// Returns true only if <paramref name="obj"/> is a <see cref="SymbolFactor"/> with the same character.
  /// </summary>
  public override bool Equals(object obj) {
    var that = obj as SymbolFactor;
    return that != null && that.symbolId == symbolId;
  }
}
/// <summary>
/// Factor that applies a unary function (log, exp or inv) to a polynomial argument.
/// </summary>
internal class FunctionFactor : Factor {
  internal UnaryFunctionType functionType;
  internal Polynomial argument;

  /// <summary>Creates a factor representing <paramref name="functionType"/> applied to <paramref name="argument"/>.</summary>
  public FunctionFactor(UnaryFunctionType functionType, Polynomial argument) {
    this.functionType = functionType;
    this.argument = argument;
  }

  /// <summary>
  /// Combines the hash of the function type (scaled by 33 via shift-and-add) with the hash of the argument.
  /// </summary>
  public override int GetHashCode() {
    int typeHash = functionType.GetHashCode();
    int scaledTypeHash = (typeHash << 5) + typeHash;
    return scaledTypeHash ^ argument.GetHashCode();
  }

  /// <summary>
  /// Returns true only if <paramref name="obj"/> is a <see cref="FunctionFactor"/> with the same
  /// function type and the very same argument instance.
  /// </summary>
  public override bool Equals(object obj) {
    var that = obj as FunctionFactor;
    if (that == null) return false;
    if (that.functionType != functionType) return false;
    // NOTE(review): Polynomial declares no operator== in this file, so this is a reference comparison,
    // whereas GetHashCode hashes the argument structurally. Presumably intended so that e.g.
    // log(x1) + log(x1) keeps two distinct terms — confirm before changing.
    return that.argument == argument;
  }
}
64
/// <summary>
/// A product of factors. The order of the factors is irrelevant for hashing and equality,
/// but the multiplicity of each factor is significant.
/// </summary>
internal class Monomial {
  internal List<Factor> factors = new List<Factor>();

  /// <summary>Creates a monomial containing the given factors (in the given order).</summary>
  public Monomial(params Factor[] factor) {
    foreach (var f in factor) factors.Add(f);
  }

  /// <summary>
  /// Order-independent hash: the factor hashes are sorted and then folded with a
  /// shift-add-xor step, so permutations of the same factors hash identically.
  /// </summary>
  public override int GetHashCode() {
    return factors.OrderBy(ti => ti.GetHashCode()).Aggregate(0, (a, ti) => ((a << 5) + a) ^ ti.GetHashCode());
  }

  /// <summary>
  /// Returns true only if <paramref name="obj"/> is a <see cref="Monomial"/> that contains the same
  /// factors with the same multiplicities, regardless of order.
  /// </summary>
  public override bool Equals(object obj) {
    Monomial other = obj as Monomial;
    if (other == null) return false;
    if (other.factors.Count != this.factors.Count) return false;

    // Compare as multisets: every factor of this monomial must consume one distinct matching
    // factor of the other. A plain mutual Contains check would treat [a, a, b] and [a, b, b]
    // as equal although GetHashCode distinguishes them.
    var unmatched = new List<Factor>(other.factors);
    foreach (var f in factors) {
      if (!unmatched.Remove(f)) return false;
    }
    // Counts are equal, so nothing can be left over once every factor found a partner.
    return true;
  }

  /// <summary>
  /// Multiplies two monomials and simplifies the result:
  /// all exp factors are merged into a single exp factor whose argument collects the terms of all
  /// exp arguments (exp(p) * exp(q) = exp(p + q)), all inv factors are merged into a single inv
  /// factor whose argument is the product of all inv arguments (1/p * 1/q = 1/(p*q)),
  /// and all remaining factors are copied unchanged.
  /// </summary>
  /// <returns>A new monomial; <paramref name="a"/> and <paramref name="b"/> are not modified.</returns>
  public static Monomial Product(Monomial a, Monomial b) {
    var p = new Monomial();
    var invFactorArg = new Polynomial();
    var expFactorArg = new Polynomial();

    foreach (var aFactor in a.factors.Concat(b.factors)) {
      // collect all exp and inv factors into one and simplify
      var funcFactor = aFactor as FunctionFactor;
      if (funcFactor != null && funcFactor.functionType == UnaryFunctionType.Exp) {
        expFactorArg.Add(funcFactor.argument);
      } else if (funcFactor != null && funcFactor.functionType == UnaryFunctionType.Inv) {
        // the first inv argument is copied (so the source polynomial is never mutated),
        // every further inv argument is multiplied onto the accumulated argument
        if (!invFactorArg.terms.Any()) invFactorArg.terms = new HashSet<Monomial>(funcFactor.argument.terms);
        else invFactorArg.Mul(funcFactor.argument);
      } else {
        p.factors.Add(aFactor);
      }
    }
    // merged function factors are appended only if at least one such factor was encountered
    if (expFactorArg.terms.Any()) {
      var expFactor = new FunctionFactor(UnaryFunctionType.Exp, expFactorArg);
      p.factors.Add(expFactor);
    }
    if (invFactorArg.terms.Any()) {
      var invFactor = new FunctionFactor(UnaryFunctionType.Inv, invFactorArg);
      p.factors.Add(invFactor);
    }
    return p;
  }
}
109
/// <summary>
/// A sum of monomials stored as a set, so adding a term that is already present has no effect.
/// </summary>
internal class Polynomial {
  internal HashSet<Monomial> terms = new HashSet<Monomial>();

  /// <summary>Creates a polynomial from the given terms; duplicates are dropped by the set.</summary>
  public Polynomial(params Monomial[] term) {
    for (int i = 0; i < term.Length; i++) {
      terms.Add(term[i]);
    }
  }

  /// <summary>Adds all terms of <paramref name="other"/> to this polynomial (in place).</summary>
  public void Add(Polynomial other) {
    var addedTerms = other.terms;
    this.terms.UnionWith(addedTerms);
  }

  /// <summary>
  /// Replaces the terms of this polynomial by the pairwise products of its terms with the
  /// terms of <paramref name="other"/>.
  /// </summary>
  public void Mul(Polynomial other) {
    var lhs = terms;
    var rhs = other.terms;
    // the new set is fully built before it replaces the old one
    terms = new HashSet<Monomial>(lhs.SelectMany(left => rhs.Select(right => Monomial.Product(left, right))));
  }

  /// <summary>
  /// Order-independent hash: the term hashes are sorted and folded with a shift-add-xor step.
  /// </summary>
  public override int GetHashCode() {
    int hash = 0;
    foreach (int termHash in terms.Select(t => t.GetHashCode()).OrderBy(h => h)) {
      hash = ((hash << 5) + hash) ^ termHash;
    }
    return hash;
  }

  /// <summary>
  /// Returns true only if <paramref name="obj"/> is a <see cref="Polynomial"/> whose term set
  /// has the same size and contains the same terms.
  /// </summary>
  public override bool Equals(object obj) {
    var that = obj as Polynomial;
    if (that == null) return false;
    if (that.terms.Count != terms.Count) return false;

    foreach (var t in terms) {
      if (!that.terms.Contains(t)) return false;
    }
    foreach (var t in that.terms) {
      if (!terms.Contains(t)) return false;
    }
    return true;
  }
}
140
141
142  // calculates a hash-code for expressions.
143
/// <summary>
/// Calculates a hash code for expressions given as byte code. The code is evaluated symbolically
/// on a stack of <c>Polynomial</c> objects and the hash code of the resulting polynomial is returned.
/// </summary>
public static class ExprHashSymbolic {

  const int MaxStackSize = 100;
  // number of variable symbols that are created once and reused
  const int MaxVariables = 26;
  private static readonly SymbolFactor[] varSymbols;
  private static readonly SymbolFactor zero;
  private static readonly SymbolFactor one;

  static ExprHashSymbolic() {
    varSymbols = new SymbolFactor[MaxVariables];
    for (int i = 0; i < MaxVariables; i++) {
      varSymbols[i] = CreateVarSymbol(i); // 'a' .. 'z'
    }
    zero = new SymbolFactor('0');
    one = new SymbolFactor('1');
  }

  /// <summary>Returns the hash code for the expression encoded in <paramref name="code"/>.</summary>
  /// <param name="code">Byte code of the expression; evaluation stops at the first Exit op code.</param>
  /// <param name="nParams">Number of parameters of the expression; currently not used by the calculation.</param>
  /// <exception cref="ArgumentNullException"><paramref name="code"/> is null.</exception>
  /// <exception cref="InvalidOperationException">The code contains an unknown op code or a negative variable index.</exception>
  public static int GetHash(byte[] code, int nParams) {
    if (code == null) throw new ArgumentNullException("code");
    return Eval(code, nParams);
  }

  private static int Eval(byte[] code, int nParams) {
    // The hash code calculation already preserves commutativity, associativity and distributivity of operations.
    // We assume that the structure contains numeric parameters
    // x = c*x
    // exp(x) = c*exp(c*x)
    // log(x) = c*log(x+c)
    // inv(x) = c/(x+c)
    // Accordingly, each structure represents a class of functions.

    // The following expressions should hash to the same value as they represent
    // equivalent function classes
    // - x1 + x1 is equivalent to x1
    // - exp(x1) * exp(x1) is equivalent to exp(x1)
    //
    // The following expressions must not hash to the same value.
    // - exp(x1) + exp(x1) is different from exp(x1)
    // - log(x1) + log(x1) is different from log(x1)
    // - 1/x1 + 1/x1 is different from 1/x1
    // - TODO list all

    // think about speed later (TODO)

    var stack = new Polynomial[MaxStackSize];
    int topOfStack = -1;
    int pc = 0;
    OpCodes op;
    short arg;
    while (true) {
      ReadNext(code, ref pc, out op, out arg);
      switch (op) {
        case OpCodes.Nop: break;
        case OpCodes.LoadConst0: {
            ++topOfStack;
            stack[topOfStack] = new Polynomial(new Monomial(zero));
            break;
          }
        case OpCodes.LoadConst1: {
            ++topOfStack;
            stack[topOfStack] = new Polynomial(new Monomial(one));
            break;
          }
        case OpCodes.LoadParamN: {
            // parameters are represented by the same symbol as the constant one
            ++topOfStack;
            stack[topOfStack] = new Polynomial(new Monomial(one)); // TODO empty, or unique?
            break;
          }
        case OpCodes.LoadVar: {
            ++topOfStack;
            stack[topOfStack] = new Polynomial(new Monomial(GetVarSymbol(arg)));
            break;
          }
        case OpCodes.Add: {
            // binary operations combine the two topmost polynomials in place and pop one element
            stack[topOfStack - 1].Add(stack[topOfStack]);
            topOfStack--;
            break;
          }
        case OpCodes.Mul: {
            stack[topOfStack - 1].Mul(stack[topOfStack]);
            topOfStack--;
            break;
          }
        case OpCodes.Log: {
            // unary functions wrap the topmost polynomial into a single function factor
            var v1 = stack[topOfStack];
            stack[topOfStack] = new Polynomial(new Monomial(new FunctionFactor(UnaryFunctionType.Log, v1)));
            break;
          }
        case OpCodes.Exp: {
            var v1 = stack[topOfStack];
            stack[topOfStack] = new Polynomial(new Monomial(new FunctionFactor(UnaryFunctionType.Exp, v1)));
            break;
          }
        case OpCodes.Inv: {
            var v1 = stack[topOfStack];
            stack[topOfStack] = new Polynomial(new Monomial(new FunctionFactor(UnaryFunctionType.Inv, v1)));
            break;
          }
        case OpCodes.Exit:
          Contract.Assert(topOfStack == 0);
          return stack[topOfStack].GetHashCode();
        default: throw new InvalidOperationException();
      }
    }
  }

  // Returns the symbol for the variable with the given index. The first MaxVariables symbols are
  // shared instances; symbols for larger indices are created on demand (SymbolFactor equality is
  // based on the character only, so separately created symbols for the same index are equal).
  private static SymbolFactor GetVarSymbol(int varIdx) {
    if (varIdx < 0) throw new InvalidOperationException("Negative variable index in expression code: " + varIdx);
    return varIdx < varSymbols.Length ? varSymbols[varIdx] : CreateVarSymbol(varIdx);
  }

  // Maps a variable index to a unique character starting at 'a'. The index is read as a short,
  // so 'a' + varIdx always fits into a char and never collides with the symbols '0' and '1'.
  private static SymbolFactor CreateVarSymbol(int varIdx) {
    return new SymbolFactor((char)('a' + varIdx));
  }

  // Reads the op code at position pc and advances pc. For LoadVar the two following bytes are
  // read as a big-endian 16-bit argument (the variable index); for all other op codes s is 0.
  private static void ReadNext(byte[] code, ref int pc, out OpCodes op, out short s) {
    op = (OpCodes)code[pc++];
    s = 0;
    if (op == OpCodes.LoadVar) {
      s = (short)((code[pc] << 8) | code[pc + 1]);
      pc += 2;
    }
  }
}
270}
Note: See TracBrowser for help on using the repository browser.