source: branches/2929_PrioritizedGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.PGE/3.3/go-code/go-pge/pge/pge_search.go @ 16080

Last change on this file since 16080 was 16080, checked in by hmaislin, 13 months ago

#2929 initial commit of working PGE version

File size: 15.8 KB
Line 
1package pge
2
3import (
4  "bufio"
5  "fmt"
6  "io/ioutil"
7  "log"
8  "math"
9  "os"
10  "strconv"
11  "strings"
12
13  levmar "github.com/verdverm/go-levmar"
14  config "github.com/verdverm/go-pge/config"
15  probs "github.com/verdverm/go-pge/problems"
16  expr "github.com/verdverm/go-symexpr"
17)
18
19type pgeConfig struct {
20  // search params
21  maxGen        int
22  pgeRptEpoch   int
23  pgeRptCount   int
24  pgeArchiveCap int
25
26  simprules expr.SimpRules
27  treecfg   *probs.TreeParams
28
29  // PGE specific options
30  peelCnt     int
31  sortType    probs.SortType
32  zeroEpsilon float64
33
34  initMethod string
35  growMethod string
36
37  evalrCount int
38}
39
40func pgeConfigParser(field, value string, config interface{}) (err error) {
41
42  PC := config.(*pgeConfig)
43
44  switch strings.ToUpper(field) {
45  case "MAXGEN":
46    PC.maxGen, err = strconv.Atoi(value)
47  case "PGERPTEPOCH":
48    PC.pgeRptEpoch, err = strconv.Atoi(value)
49  case "PGERPTCOUNT":
50    PC.pgeRptCount, err = strconv.Atoi(value)
51  case "PGEARCHIVECAP":
52    PC.pgeArchiveCap, err = strconv.Atoi(value)
53
54  case "PEELCOUNT":
55    PC.peelCnt, err = strconv.Atoi(value)
56
57  case "EVALRCOUNT":
58    PC.evalrCount, err = strconv.Atoi(value)
59
60  case "SORTTYPE":
61    switch strings.ToLower(value) {
62    case "paretotrainerror":
63      PC.sortType = probs.PESORT_PARETO_TRN_ERR
64    case "paretotesterror":
65      PC.sortType = probs.PESORT_PARETO_TST_ERR
66
67    default:
68      log.Printf("PGE Config Not Implemented: %s, %s\n\n", field, value)
69    }
70
71  case "ZEROEPSILON":
72    PC.zeroEpsilon, err = strconv.ParseFloat(value, 64)
73
74  default:
75    // check augillary parsable structures [only TreeParams for now]
76    if PC.treecfg == nil {
77      PC.treecfg = new(probs.TreeParams)
78    }
79    found, ferr := probs.ParseTreeParams(field, value, PC.treecfg)
80    if ferr != nil {
81      log.Fatalf("error parsing PGE - treecfg Config\n")
82      return ferr
83    }
84    if !found {
85      log.Printf("PGE Config Not Implemented: %s, %s\n\n", field, value)
86    }
87
88  }
89  return
90}
91
92type PgeSearch struct {
93  id   int
94  cnfg pgeConfig
95  prob *probs.ExprProblem
96  iter int
97  stop bool
98
99  // comm up
100  commup *probs.ExprProblemComm
101
102  // comm down
103
104  // best exprs
105  Best *probs.ReportQueue
106
107  // training data in C format
108  c_input  []levmar.C_double
109  c_ygiven []levmar.C_double
110
111  // logs
112  logDir     string
113  mainLog    *log.Logger
114  mainLogBuf *bufio.Writer
115  eqnsLog    *log.Logger
116  eqnsLogBuf *bufio.Writer
117  errLog     *log.Logger
118  errLogBuf  *bufio.Writer
119
120  fitnessLog    *log.Logger
121  fitnessLogBuf *bufio.Writer
122  ipreLog       *log.Logger
123  ipreLogBuf    *bufio.Writer
124
125  // equations visited
126  Trie  *IpreNode
127  Queue *probs.ReportQueue
128
129  // eval channels
130  eval_in  chan expr.Expr
131  eval_out chan *probs.ExprReport
132
133  // genStuff
134  GenRoots   []expr.Expr
135  GenLeafs   []expr.Expr
136  GenNodes   []expr.Expr
137  GenNonTrig []expr.Expr
138
139  // FFXish stuff
140  ffxBases []expr.Expr
141
142  // statistics
143  neqns    int
144  ipreCnt  int
145  maxSize  int
146  maxScore int
147  minError float64
148}
149
150func (PS *PgeSearch) GetMaxIter() int {
151  return PS.cnfg.maxGen
152}
153func (PS *PgeSearch) SetMaxIter(iter int) {
154  PS.cnfg.maxGen = iter
155}
156func (PS *PgeSearch) SetPeelCount(cnt int) {
157  PS.cnfg.peelCnt = cnt
158}
159func (PS *PgeSearch) SetInitMethod(init string) {
160  PS.cnfg.initMethod = init
161}
162func (PS *PgeSearch) SetGrowMethod(grow string) {
163  PS.cnfg.growMethod = grow
164}
165func (PS *PgeSearch) SetEvalrCount(cnt int) {
166  PS.cnfg.evalrCount = cnt
167}
168
169func (PS *PgeSearch) ParseConfig(filename string) {
170  fmt.Printf("Parsing PGE Config: %s\n", filename)
171  data, err := ioutil.ReadFile(filename)
172  if err != nil {
173    log.Fatal(err)
174  }
175  err = config.ParseConfig(data, pgeConfigParser, &PS.cnfg)
176  if err != nil {
177    log.Fatal(err)
178  }
179}
180
181func (PS *PgeSearch) Init(done chan int, prob *probs.ExprProblem, logdir string, input interface{}) {
182  fmt.Printf("Init'n PGE\n")
183  // setup data
184
185  // open logs
186  PS.initLogs(logdir)
187
188  // copy in common config options
189  PS.prob = prob
190  if PS.cnfg.treecfg == nil {
191    PS.cnfg.treecfg = PS.prob.TreeCfg.Clone()
192  }
193  srules := expr.DefaultRules()
194  srules.ConvertConsts = true
195  PS.cnfg.simprules = srules
196
197  fmt.Println("Roots:   ", PS.cnfg.treecfg.RootsS)
198  fmt.Println("Nodes:   ", PS.cnfg.treecfg.NodesS)
199  fmt.Println("Leafs:   ", PS.cnfg.treecfg.LeafsS)
200  fmt.Println("NonTrig: ", PS.cnfg.treecfg.NonTrigS)
201
202  PS.GenRoots = make([]expr.Expr, len(PS.cnfg.treecfg.Roots))
203  for i := 0; i < len(PS.GenRoots); i++ {
204    PS.GenRoots[i] = PS.cnfg.treecfg.Roots[i].Clone()
205  }
206  PS.GenNodes = make([]expr.Expr, len(PS.cnfg.treecfg.Nodes))
207  for i := 0; i < len(PS.GenNodes); i++ {
208    PS.GenNodes[i] = PS.cnfg.treecfg.Nodes[i].Clone()
209  }
210  PS.GenNonTrig = make([]expr.Expr, len(PS.cnfg.treecfg.NonTrig))
211  for i := 0; i < len(PS.GenNonTrig); i++ {
212    PS.GenNonTrig[i] = PS.cnfg.treecfg.NonTrig[i].Clone()
213  }
214
215  PS.GenLeafs = make([]expr.Expr, 0)
216  for _, t := range PS.cnfg.treecfg.LeafsT {
217    switch t {
218    case expr.TIME:
219      PS.GenLeafs = append(PS.GenLeafs, expr.NewTime())
220
221    case expr.VAR:
222      fmt.Println("Use Vars: ", PS.cnfg.treecfg.UsableVars)
223      for _, i := range PS.cnfg.treecfg.UsableVars {
224        PS.GenLeafs = append(PS.GenLeafs, expr.NewVar(i))
225      }
226
227    case expr.SYSTEM:
228      for i := 0; i < PS.prob.Train[0].NumSys(); i++ {
229        PS.GenLeafs = append(PS.GenLeafs, expr.NewSystem(i))
230      }
231
232    }
233  }
234  /*** FIX ME
235  PS.GenLeafs = make([]expr.Expr, len(PS.cnfg.treecfg.Leafs))
236  for i := 0; i < len(PS.GenLeafs); i++ {
237    PS.GenLeafs[i] = PS.cnfg.treecfg.Leafs[i].Clone()
238  }
239  ***/
240
241  fmt.Println("Roots:   ", PS.GenRoots)
242  fmt.Println("Nodes:   ", PS.GenNodes)
243  fmt.Println("Leafs:   ", PS.GenLeafs)
244  fmt.Println("NonTrig: ", PS.GenNonTrig)
245
246  // setup communication struct
247  PS.commup = input.(*probs.ExprProblemComm)
248
249  // initialize bbq
250  PS.Trie = new(IpreNode)
251  PS.Trie.val = -1
252  PS.Trie.next = make(map[int]*IpreNode)
253
254  PS.Best = probs.NewReportQueue()
255  PS.Best.SetSort(probs.GPSORT_PARETO_TST_ERR)
256  PS.Queue = PS.GenInitExpr()
257  PS.Queue.SetSort(probs.PESORT_PARETO_TST_ERR)
258
259  PS.neqns = PS.Queue.Len()
260
261  PS.minError = math.Inf(1)
262
263  PS.eval_in = make(chan expr.Expr, 4048)
264  PS.eval_out = make(chan *probs.ExprReport, 4048)
265
266  for i := 0; i < PS.cnfg.evalrCount; i++ {
267    go PS.Evaluate()
268  }
269}
270
271func (PS *PgeSearch) Evaluate() {
272
273  for !PS.stop {
274    e := <-PS.eval_in
275    if e == nil {
276      continue
277    }
278    PS.eval_out <- RegressExpr(e, PS.prob)
279  }
280
281}
282
283func (PS *PgeSearch) Run() {
284  fmt.Printf("Running PGE\n")
285
286  PS.loop()
287
288  fmt.Println("PGE exitting")
289
290  PS.Clean()
291  PS.commup.Cmds <- -1
292}
293
294func (PS *PgeSearch) loop() {
295
296  PS.checkMessages()
297  for !PS.stop {
298
299    fmt.Println("in: PS.step() ", PS.iter)
300    PS.step()
301
302    // if PS.iter%PS.cnfg.pgeRptEpoch == 0 {
303    PS.reportExpr()
304    // }
305
306    // report current iteration
307    PS.commup.Gen <- [2]int{PS.id, PS.iter}
308    PS.iter++
309
310    PS.Clean()
311
312    PS.checkMessages()
313
314  }
315
316  // done expanding, pull the rest of the regressed solutions from the queue
317  p := 0
318  for PS.Queue.Len() > 0 {
319    e := PS.Queue.Pop().(*probs.ExprReport)
320
321    bPush := true
322    if len(e.Coeff()) == 1 && math.Abs(e.Coeff()[0]) < PS.cnfg.zeroEpsilon {
323      // fmt.Println("No Best Push")
324      bPush = false
325    }
326
327    if bPush {
328      // fmt.Printf("pop/push(%d,%d): %v\n", p, PS.Best.Len(), e.Expr())
329      PS.Best.Push(e)
330      p++
331    }
332
333    if e.TestScore() > PS.maxScore {
334      PS.maxScore = e.TestScore()
335    }
336    if e.TestError() < PS.minError {
337      PS.minError = e.TestError()
338      fmt.Printf("EXITING New Min Error:  %v\n", e)
339    }
340    if e.Size() > PS.maxSize {
341      PS.maxSize = e.Size()
342    }
343  }
344
345  fmt.Println("PGE sending last report")
346  PS.reportExpr()
347
348}
349
350func (PS *PgeSearch) step() {
351
352  loop := 0
353  eval_cnt := 0 // for channeled eval
354
355  es := PS.peel()
356
357  ex := PS.expandPeeled(es)
358
359  for cnt := range ex {
360    E := ex[cnt]
361
362    if E == nil {
363      continue
364    }
365
366    for _, e := range E {
367      if e == nil {
368        continue
369      }
370      if !PS.cnfg.treecfg.CheckExpr(e) {
371        continue
372      }
373
374      // check ipre_trie
375      serial := make([]int, 0, 64)
376      serial = e.Serial(serial)
377      ins := PS.Trie.InsertSerial(serial)
378      if !ins {
379        continue
380      }
381
382      // for serial eval
383      // re := RegressExpr(e, PS.prob)
384
385      // start channeled eval
386      PS.eval_in <- e
387      eval_cnt++
388    }
389  }
390  for i := 0; i < eval_cnt; i++ {
391    re := <-PS.eval_out
392    // end channeled eval
393
394    // check for NaN/Inf in re.error  and  if so, skip
395    if math.IsNaN(re.TestError()) || math.IsInf(re.TestError(), 0) {
396      // fmt.Printf("Bad Error\n%v\n", re)
397      continue
398    }
399
400    if re.TestError() < PS.minError {
401      PS.minError = re.TestError()
402    }
403
404    // check for coeff == 0
405    doIns := true
406    for _, c := range re.Coeff() {
407      // i > 0 for free coeff
408      if math.Abs(c) < PS.cnfg.zeroEpsilon {
409        doIns = false
410        break
411      }
412    }
413
414    if doIns {
415      re.SetProcID(PS.id)
416      re.SetIterID(PS.iter)
417      re.SetUnitID(loop)
418      re.SetUniqID(PS.neqns)
419      loop++
420      PS.neqns++
421      // fmt.Printf("Queue.Push(): %v\n%v\n\n", re.Expr(), serial)
422      // fmt.Printf("Queue.Push(): %v\n", re)
423      // fmt.Printf("Queue.Push(): %v\n", re.Expr())
424
425      PS.Queue.Push(re)
426
427    }
428  }
429  // } // for sequential eval
430  PS.Queue.Sort()
431
432}
433
434func (PS *PgeSearch) peel() []*probs.ExprReport {
435  es := make([]*probs.ExprReport, PS.cnfg.peelCnt)
436  for p := 0; p < PS.cnfg.peelCnt && PS.Queue.Len() > 0; p++ {
437
438    e := PS.Queue.Pop().(*probs.ExprReport)
439
440    bPush := true
441    if len(e.Coeff()) == 1 && math.Abs(e.Coeff()[0]) < PS.cnfg.zeroEpsilon {
442      fmt.Println("No Best Push")
443      p--
444      continue
445    }
446
447    if bPush {
448      fmt.Printf("pop/push(%d,%d): %v\n", p, PS.Best.Len(), e.Expr())
449      PS.Best.Push(e)
450    }
451
452    es[p] = e
453
454    if e.TestScore() > PS.maxScore {
455      PS.maxScore = e.TestScore()
456    }
457    if e.TestError() < PS.minError {
458      PS.minError = e.TestError()
459      fmt.Printf("Best New Min Error:  %v\n", e)
460    }
461    if e.Size() > PS.maxSize {
462      PS.maxSize = e.Size()
463    }
464
465  }
466  return es
467}
468
469func (PS *PgeSearch) expandPeeled(es []*probs.ExprReport) [][]expr.Expr {
470  eqns := make([][]expr.Expr, PS.cnfg.peelCnt)
471  for p := 0; p < PS.cnfg.peelCnt; p++ {
472    if es[p] == nil {
473      continue
474    }
475    // fmt.Printf("expand(%d): %v\n", p, es[p].Expr())
476    if es[p].Expr().ExprType() != expr.ADD {
477      add := expr.NewAdd()
478      add.Insert(es[p].Expr())
479      add.CalcExprStats()
480      es[p].SetExpr(add)
481    }
482    eqns[p] = PS.Expand(es[p].Expr())
483    // fmt.Printf("Results:\n")
484    // for i, e := range eqns[p] {
485    //  fmt.Printf("%d,%d:  %v\n", p, i, e)
486    // }
487    // fmt.Println()
488  }
489  fmt.Println("\n")
490  return eqns
491}
492
493func (PS *PgeSearch) reportExpr() {
494
495  cnt := PS.cnfg.pgeRptCount
496  PS.Best.Sort()
497
498  // repot best equations
499  rpt := make(probs.ExprReportArray, cnt)
500  if PS.Best.Len() < cnt {
501    cnt = PS.Best.Len()
502  }
503  copy(rpt, PS.Best.GetQueue()[:cnt])
504
505  errSum, errCnt := 0.0, 0
506  PS.eqnsLog.Println("\n\nReport", PS.iter)
507  for i, r := range rpt {
508    PS.eqnsLog.Printf("\n%d:  %v\n", i, r)
509    if r != nil && r.Expr() != nil {
510      errSum += r.TestError()
511      errCnt++
512    }
513  }
514
515  PS.mainLog.Printf("Iter: %d  %f  %f\n", PS.iter, errSum/float64(errCnt), PS.minError)
516
517  PS.ipreLog.Println(PS.iter, PS.neqns, PS.Trie.cnt, PS.Trie.vst)
518  PS.fitnessLog.Println(PS.iter, PS.neqns, PS.Trie.cnt, PS.Trie.vst, errSum/float64(errCnt), PS.minError)
519
520  PS.commup.Rpts <- &rpt
521
522}
523
524func (PS *PgeSearch) Clean() {
525  // fmt.Printf("Cleaning PGE\n")
526
527  PS.errLogBuf.Flush()
528  PS.mainLogBuf.Flush()
529  PS.eqnsLogBuf.Flush()
530  PS.fitnessLogBuf.Flush()
531  PS.ipreLogBuf.Flush()
532
533}
534
535func (PS *PgeSearch) initLogs(logdir string) {
536  // open logs
537  PS.logDir = logdir + "pge/"
538  os.Mkdir(PS.logDir, os.ModePerm)
539  tmpF0, err5 := os.Create(PS.logDir + "pge:err.log")
540  if err5 != nil {
541    log.Fatal("couldn't create errs log")
542  }
543  PS.errLogBuf = bufio.NewWriter(tmpF0)
544  PS.errLogBuf.Flush()
545  PS.errLog = log.New(PS.errLogBuf, "", log.LstdFlags)
546
547  tmpF1, err1 := os.Create(PS.logDir + "pge:main.log")
548  if err1 != nil {
549    log.Fatal("couldn't create main log")
550  }
551  PS.mainLogBuf = bufio.NewWriter(tmpF1)
552  PS.mainLogBuf.Flush()
553  PS.mainLog = log.New(PS.mainLogBuf, "", log.LstdFlags)
554
555  tmpF2, err2 := os.Create(PS.logDir + "pge:eqns.log")
556  if err2 != nil {
557    log.Fatal("couldn't create eqns log")
558  }
559  PS.eqnsLogBuf = bufio.NewWriter(tmpF2)
560  PS.eqnsLogBuf.Flush()
561  PS.eqnsLog = log.New(PS.eqnsLogBuf, "", 0)
562
563  tmpF3, err3 := os.Create(PS.logDir + "pge:fitness.log")
564  if err3 != nil {
565    log.Fatal("couldn't create eqns log")
566  }
567  PS.fitnessLogBuf = bufio.NewWriter(tmpF3)
568  PS.fitnessLogBuf.Flush()
569  PS.fitnessLog = log.New(PS.fitnessLogBuf, "", log.Ltime|log.Lmicroseconds)
570
571  tmpF4, err4 := os.Create(PS.logDir + "pge:ipre.log")
572  if err4 != nil {
573    log.Fatal("couldn't create eqns log")
574  }
575  PS.ipreLogBuf = bufio.NewWriter(tmpF4)
576  PS.ipreLogBuf.Flush()
577  PS.ipreLog = log.New(PS.ipreLogBuf, "", log.Ltime|log.Lmicroseconds)
578}
579
580func (PS *PgeSearch) checkMessages() {
581
582  // check messages from superior
583  select {
584  case cmd, ok := <-PS.commup.Cmds:
585    if ok {
586      if cmd == -1 {
587        fmt.Println("PGE: stop sig recv'd")
588        PS.stop = true
589        return
590      }
591    }
592  default:
593    return
594  }
595}
596
597var c_input, c_ygiven []levmar.C_double
598
599func RegressExpr(E expr.Expr, P *probs.ExprProblem) (R *probs.ExprReport) {
600
601  guess := make([]float64, 0)
602  guess, eqn := E.ConvertToConstants(guess)
603
604  var coeff []float64
605  if len(guess) > 0 {
606
607    // fmt.Printf("x_dims:  %d  %d\n", x_dim, x_dim2)
608
609    // Callback version
610    coeff = levmar.LevmarExpr(eqn, P.SearchVar, P.SearchType, guess, P.Train, P.Test)
611
612    // Stack version
613    // x_dim := P.Train[0].NumDim()
614    // if c_input == nil {
615    //  ps := P.Train[0].NumPoints()
616    //  PS := len(P.Train) * ps
617    //  x_tot := PS * x_dim
618
619    //  c_input = make([]levmar.C_double, x_tot)
620    //  c_ygiven = make([]levmar.C_double, PS)
621
622    //  for i1, T := range P.Train {
623    //    for i2, p := range T.Points() {
624    //      i := i1*ps + i2
625    //      c_ygiven[i] = levmar.MakeCDouble(p.Depnd(P.SearchVar))
626    //      for i3, x_p := range p.Indeps() {
627    //        j := i1*ps*x_dim + i2*x_dim + i3
628    //        c_input[j] = levmar.MakeCDouble(x_p)
629    //      }
630    //    }
631    //  }
632    // }
633    // coeff = levmar.StackLevmarExpr(eqn, x_dim, guess, c_ygiven, c_input)
634
635    // serial := make([]int, 0)
636    // serial = eqn.StackSerial(serial)
637    // fmt.Printf("StackSerial: %v\n", serial)
638    // fmt.Printf("%v\n%v\n%v\n\n", eqn, coeff, steff)
639  }
640
641  R = new(probs.ExprReport)
642  R.SetExpr(eqn) /*.ConvertToConstantFs(coeff)*/
643  R.SetCoeff(coeff)
644  R.Expr().CalcExprStats()
645
646  // hitsL1, hitsL2, evalCnt, nanCnt, infCnt, l1_err, l2_err := scoreExpr(E, P, coeff)
647  _, _, _, trnNanCnt, _, trn_l1_err, _ := scoreExpr(E, P, P.Train, coeff)
648  _, _, tstEvalCnt, tstNanCnt, _, tst_l1_err, tst_l2_err := scoreExpr(E, P, P.Test, coeff)
649
650  R.SetTrainScore(trnNanCnt)
651  R.SetTrainError(trn_l1_err)
652
653  R.SetPredScore(tstNanCnt)
654  R.SetTestScore(tstEvalCnt)
655  R.SetTestError(tst_l1_err)
656  R.SetPredError(tst_l2_err)
657
658  return R
659}
660
661func scoreExpr(e expr.Expr, P *probs.ExprProblem, dataSets []*probs.PointSet, coeff []float64) (hitsL1, hitsL2, evalCnt, nanCnt, infCnt int, l1_err, l2_err float64) {
662  var l1_sum, l2_sum float64
663  for _, PS := range dataSets {
664    for _, p := range PS.Points() {
665      y := p.Depnd(P.SearchVar)
666      var out float64
667      if P.SearchType == probs.ExprBenchmark {
668        out = e.Eval(0, p.Indeps(), coeff, PS.SysVals())
669      } else if P.SearchType == probs.ExprDiffeq {
670        out = e.Eval(p.Indep(0), p.Indeps()[1:], coeff, PS.SysVals())
671      }
672
673      if math.IsNaN(out) {
674        nanCnt++
675        continue
676      } else if math.IsInf(out, 0) {
677        infCnt++
678        continue
679      } else {
680        evalCnt++
681      }
682
683      diff := out - y
684      l1_val := math.Abs(diff)
685      l2_val := diff * diff
686      l1_sum += l1_val
687      l2_sum += l2_val
688
689      if l1_val < P.HitRatio {
690        hitsL1++
691      }
692      if l2_val < P.HitRatio {
693        hitsL2++
694      }
695    }
696  }
697
698  if evalCnt == 0 {
699    l1_err = math.NaN()
700    l2_err = math.NaN()
701  } else {
702    fEvalCnt := float64(evalCnt + 1)
703    l1_err = l1_sum / fEvalCnt
704    l2_err = math.Sqrt(l2_sum / fEvalCnt)
705  }
706
707  return
708}
Note: See TracBrowser for help on using the repository browser.