Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2929_PrioritizedGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.PGE/3.3/go-code/go-pge/pge/pge_search.go @ 16230

Last change on this file since 16230 was 16230, checked in by hmaislin, 6 years ago

#2929: Adapted pge plugin to check for null value

File size: 19.9 KB
Line 
1package pge
2
3import (
4  "bufio"
5  "fmt"
6  "io/ioutil"
7  "log"
8  "math"
9  "os"
10  "strconv"
11  "strings"
12
13  levmar "github.com/verdverm/go-levmar"
14  config "github.com/verdverm/go-pge/config"
15  probs "github.com/verdverm/go-pge/problems"
16  expr "github.com/verdverm/go-symexpr"
17)
18
19type PgeConfig struct {
20  // search params
21  maxGen        int
22  pgeRptEpoch   int
23  pgeRptCount   int
24  pgeArchiveCap int
25
26  simprules expr.SimpRules
27  treecfg   *probs.TreeParams
28
29  // PGE specific options
30  peelCnt     int
31  sortType    probs.SortType
32  zeroEpsilon float64
33
34  initMethod string
35  growMethod string
36
37  evalrCount int
38}
39
40func pgeConfigParser(field, value string, config interface{}) (err error) {
41
42  PC := config.(*PgeConfig)
43
44  switch strings.ToUpper(field) {
45  case "MAXGEN":
46    PC.maxGen, err = strconv.Atoi(value)
47  case "PGERPTEPOCH":
48    PC.pgeRptEpoch, err = strconv.Atoi(value)
49  case "PGERPTCOUNT":
50    PC.pgeRptCount, err = strconv.Atoi(value)
51  case "PGEARCHIVECAP":
52    PC.pgeArchiveCap, err = strconv.Atoi(value)
53
54  case "PEELCOUNT":
55    PC.peelCnt, err = strconv.Atoi(value)
56
57  case "EVALRCOUNT":
58    PC.evalrCount, err = strconv.Atoi(value)
59
60  case "SORTTYPE":
61    switch strings.ToLower(value) {
62    case "paretotrainerror":
63      PC.sortType = probs.PESORT_PARETO_TRN_ERR
64    case "paretotesterror":
65      PC.sortType = probs.PESORT_PARETO_TST_ERR
66
67    default:
68      log.Printf("PGE Config Not Implemented: %s, %s\n\n", field, value)
69    }
70
71  case "ZEROEPSILON":
72    PC.zeroEpsilon, err = strconv.ParseFloat(value, 64)
73
74  default:
75    // check augillary parsable structures [only TreeParams for now]
76    if PC.treecfg == nil {
77      PC.treecfg = new(probs.TreeParams)
78    }
79    found, ferr := probs.ParseTreeParams(field, value, PC.treecfg)
80    if ferr != nil {
81      log.Fatalf("error parsing PGE - treecfg Config\n")
82      return ferr
83    }
84    if !found {
85      log.Printf("PGE Config Not Implemented: %s, %s\n\n", field, value)
86    }
87
88  }
89  return
90}
91
92type PgeSearch struct {
93  id   int
94  cnfg PgeConfig
95  prob *probs.ExprProblem
96  iter int
97  stop bool
98
99  // comm up
100  commup *probs.ExprProblemComm
101
102  // comm down
103
104  // best exprs
105  Best *probs.ReportQueue
106
107  // training data in C format
108  c_input  []levmar.C_double
109  c_ygiven []levmar.C_double
110
111  // logs
112  logDir     string
113  mainLog    *log.Logger
114  mainLogBuf *bufio.Writer
115  eqnsLog    *log.Logger
116  eqnsLogBuf *bufio.Writer
117  errLog     *log.Logger
118  errLogBuf  *bufio.Writer
119
120  fitnessLog    *log.Logger
121  fitnessLogBuf *bufio.Writer
122  ipreLog       *log.Logger
123  ipreLogBuf    *bufio.Writer
124
125  // equations visited
126  Trie  *IpreNode
127  Queue *probs.ReportQueue
128
129  // eval channels
130  eval_in  chan expr.Expr
131  eval_out chan *probs.ExprReport
132
133  // genStuff
134  GenRoots   []expr.Expr
135  GenLeafs   []expr.Expr
136  GenNodes   []expr.Expr
137  GenNonTrig []expr.Expr
138
139  // FFXish stuff
140  ffxBases []expr.Expr
141
142  // statistics
143  neqns    int
144  ipreCnt  int
145  maxSize  int
146  maxScore int
147  minError float64
148}
149
150func (PS *PgeSearch) GetMaxIter() int {
151  return PS.cnfg.maxGen
152}
153func (PS *PgeSearch) GetPeelCount() int {
154  return PS.cnfg.peelCnt
155}
156func (PS *PgeSearch) SetMaxIter(iter int) {
157  PS.cnfg.maxGen = iter
158}
159func (PS *PgeSearch) SetPeelCount(cnt int) {
160  PS.cnfg.peelCnt = cnt
161}
162func (PS *PgeSearch) SetInitMethod(init string) {
163  PS.cnfg.initMethod = init
164}
165func (PS *PgeSearch) SetGrowMethod(grow string) {
166  PS.cnfg.growMethod = grow
167}
168func (PS *PgeSearch) SetEvalrCount(cnt int) {
169  PS.cnfg.evalrCount = cnt
170}
171
172func (PS *PgeSearch) ParseConfig(filename string) {
173  fmt.Printf("Parsing PGE Config: %s\n", filename)
174  data, err := ioutil.ReadFile(filename)
175  if err != nil {
176    log.Fatal(err)
177  }
178  err = config.ParseConfig(data, pgeConfigParser, &PS.cnfg)
179  if err != nil {
180    log.Fatal(err)
181  }
182}
183
184func (PS *PgeSearch) Init(done chan int, prob *probs.ExprProblem, logdir string, input interface{}) {
185
186  //fmt.Printf("Init'n PGE\n")
187  // setup data
188
189  // open logs
190  PS.initLogs(logdir)
191 
192  PS.stop = false
193
194  // copy in common config options
195  PS.prob = prob
196  if PS.cnfg.treecfg == nil {
197    PS.cnfg.treecfg = PS.prob.TreeCfg.Clone()
198  }
199  srules := expr.DefaultRules()
200  srules.ConvertConsts = true
201  PS.cnfg.simprules = srules
202
203  //fmt.Println("Roots:   ", PS.cnfg.treecfg.RootsS)
204  //fmt.Println("Nodes:   ", PS.cnfg.treecfg.NodesS)
205  //fmt.Println("Leafs:   ", PS.cnfg.treecfg.LeafsS)
206  //fmt.Println("NonTrig: ", PS.cnfg.treecfg.NonTrigS)
207
208  PS.GenRoots = make([]expr.Expr, len(PS.cnfg.treecfg.Roots))
209  for i := 0; i < len(PS.GenRoots); i++ {
210    PS.GenRoots[i] = PS.cnfg.treecfg.Roots[i].Clone()
211  }
212  PS.GenNodes = make([]expr.Expr, len(PS.cnfg.treecfg.Nodes))
213  for i := 0; i < len(PS.GenNodes); i++ {
214    PS.GenNodes[i] = PS.cnfg.treecfg.Nodes[i].Clone()
215  }
216  PS.GenNonTrig = make([]expr.Expr, len(PS.cnfg.treecfg.NonTrig))
217  for i := 0; i < len(PS.GenNonTrig); i++ {
218    PS.GenNonTrig[i] = PS.cnfg.treecfg.NonTrig[i].Clone()
219  }
220 
221  fmt.Printf("Init1 %v\n", len(PS.cnfg.treecfg.LeafsT))
222
223  PS.GenLeafs = make([]expr.Expr, 0)
224  for _, t := range PS.cnfg.treecfg.LeafsT {
225    switch t {
226    case expr.TIME:
227      PS.GenLeafs = append(PS.GenLeafs, expr.NewTime())
228
229    case expr.VAR:
230      //fmt.Println("Use Vars: ", PS.cnfg.treecfg.UsableVars)
231      for _, i := range PS.cnfg.treecfg.UsableVars {
232        PS.GenLeafs = append(PS.GenLeafs, expr.NewVar(i))
233      }
234
235    case expr.SYSTEM:
236      for i := 0; i < PS.prob.Train[0].NumSys(); i++ {
237        PS.GenLeafs = append(PS.GenLeafs, expr.NewSystem(i))
238      }
239
240    }
241  }
242  fmt.Printf("Init2\n")
243  /*** FIX ME
244  PS.GenLeafs = make([]expr.Expr, len(PS.cnfg.treecfg.Leafs))
245  for i := 0; i < len(PS.GenLeafs); i++ {
246    PS.GenLeafs[i] = PS.cnfg.treecfg.Leafs[i].Clone()
247  }
248  ***/
249 
250  fmt.Printf("Init3\n")
251
252  //fmt.Println("Roots:   ", PS.GenRoots)
253  //fmt.Println("Nodes:   ", PS.GenNodes)
254  //fmt.Println("Leafs:   ", PS.GenLeafs)
255  //fmt.Println("NonTrig: ", PS.GenNonTrig)
256
257  // setup communication struct
258  PS.commup = input.(*probs.ExprProblemComm)
259
260  // initialize bbq
261  PS.Trie = new(IpreNode)
262  PS.Trie.val = -1
263  PS.Trie.next = make(map[int]*IpreNode)
264
265  PS.Best = probs.NewReportQueue()
266  PS.Best.SetSort(probs.GPSORT_PARETO_TST_ERR)
267  fmt.Printf("Init4\n")
268
269  PS.Queue = PS.GenInitExpr()
270  PS.Queue.SetSort(probs.PESORT_PARETO_TST_ERR)
271
272  PS.neqns = PS.Queue.Len()
273
274  PS.minError = math.Inf(1)
275
276  PS.eval_in = make(chan expr.Expr, 4048)
277  PS.eval_out = make(chan *probs.ExprReport, 4048)
278
279  fmt.Printf("Init5a\n")
280  for i := 0; i < PS.cnfg.evalrCount; i++ {
281    go PS.Evaluate()
282  }
283  fmt.Printf("Init5\n")
284}
285
286func (PS *PgeSearch) Evaluate() {
287
288  for !PS.stop {
289    e := <-PS.eval_in
290    if e == nil {
291      continue
292    }
293    re := RegressExpr(e, PS.prob)
294    //fmt.Printf("reg: %v\n", re)
295    PS.eval_out <- re
296  }
297
298}
299
300func (PS *PgeSearch) Run() {
301  fmt.Printf("Running PGE\n")
302
303  PS.Loop()
304
305  fmt.Println("PGE exitting")
306
307  PS.Clean()
308  PS.commup.Cmds <- -1
309}
310
311func (PS *PgeSearch) Loop() {
312
313  PS.checkMessages()
314  for !PS.stop {
315
316    fmt.Println("in: PS.step() ", PS.iter)
317    PS.Step()
318
319    // if PS.iter%PS.cnfg.pgeRptEpoch == 0 {
320    PS.reportExpr()
321    // }
322
323    // report current iteration
324    PS.commup.Gen <- [2]int{PS.id, PS.iter}
325    PS.iter++
326
327    PS.Clean()
328
329    PS.checkMessages()
330
331  }
332
333  // done expanding, pull the rest of the regressed solutions from the queue
334  p := 0
335  for PS.Queue.Len() > 0 {
336    e := PS.Queue.Pop().(*probs.ExprReport)
337
338    bPush := true
339    if len(e.Coeff()) == 1 && math.Abs(e.Coeff()[0]) < PS.cnfg.zeroEpsilon {
340      // fmt.Println("No Best Push")
341      bPush = false
342    }
343
344    if bPush {
345      // fmt.Printf("pop/push(%d,%d): %v\n", p, PS.Best.Len(), e.Expr())
346      PS.Best.Push(e)
347      p++
348    }
349
350    if e.TestScore() > PS.maxScore {
351      PS.maxScore = e.TestScore()
352    }
353    if e.TestError() < PS.minError {
354      PS.minError = e.TestError()
355      fmt.Printf("EXITING New Min Error:  %v\n", e)
356    }
357    if e.Size() > PS.maxSize {
358      PS.maxSize = e.Size()
359    }
360  }
361
362  fmt.Println("PGE sending last report")
363  PS.reportExpr()
364
365}
366
367func (PS *PgeSearch) Step() {
368  loop := 0
369  eval_cnt := 0 // for channeled eval
370
371  es := PS.peel()
372
373  ex := PS.expandPeeled(es)
374
375  for cnt := range ex {
376    E := ex[cnt]
377
378    if E == nil {
379      continue
380    }
381
382    for _, e := range E {
383      if e == nil {
384        continue
385      }
386      if !PS.cnfg.treecfg.CheckExpr(e) {
387        continue
388      }
389
390      // check ipre_trie
391      serial := make([]int, 0, 64)
392      serial = e.Serial(serial)
393      ins := PS.Trie.InsertSerial(serial)
394      if !ins {
395        continue
396      }
397
398      // for serial eval
399      // re := RegressExpr(e, PS.prob)
400
401      // start channeled eval
402      PS.eval_in <- e
403      eval_cnt++
404    }
405  }
406  for i := 0; i < eval_cnt; i++ {
407    re := <-PS.eval_out
408    // end channeled eval
409
410    // check for NaN/Inf in re.error  and  if so, skip
411    if math.IsNaN(re.TestError()) || math.IsInf(re.TestError(), 0) {
412      // fmt.Printf("Bad Error\n%v\n", re)
413      continue
414    }
415
416    if re.TestError() < PS.minError {
417      PS.minError = re.TestError()
418    }
419
420    // check for coeff == 0
421    doIns := true
422    for _, c := range re.Coeff() {
423      // i > 0 for free coeff
424      if math.Abs(c) < PS.cnfg.zeroEpsilon {
425        doIns = false
426        break
427      }
428    }
429
430    if doIns {
431      re.SetProcID(PS.id)
432      re.SetIterID(PS.iter)
433      re.SetUnitID(loop)
434      re.SetUniqID(PS.neqns)
435      loop++
436      PS.neqns++
437      // fmt.Printf("Queue.Push(): %v\n%v\n\n", re.Expr(), serial)
438      // fmt.Printf("Queue.Push(): %v\n", re)
439      // fmt.Printf("Queue.Push(): %v\n", re.Expr())
440
441      PS.Queue.Push(re)
442
443    }
444  }
445  // } // for sequential eval
446  PS.Queue.Sort()
447}
448
449func (PS *PgeSearch) peel() []*probs.ExprReport {
450  es := make([]*probs.ExprReport, PS.cnfg.peelCnt)
451  for p := 0; p < PS.cnfg.peelCnt && PS.Queue.Len() > 0; p++ {
452    e := PS.Queue.Pop().(*probs.ExprReport)
453    bPush := true
454    if len(e.Coeff()) == 1 && math.Abs(e.Coeff()[0]) < PS.cnfg.zeroEpsilon {
455      fmt.Println("No Best Push")
456      p--
457      continue
458    }
459
460    if bPush {
461      fmt.Printf("pop/push DO(%d,%d): %v\n", p, PS.Best.Len(), e.Expr())
462      PS.Best.Push(e)
463    }
464
465    es[p] = e
466
467    if e.TestScore() > PS.maxScore {
468      PS.maxScore = e.TestScore()
469    }
470    if e.TestError() < PS.minError {
471      PS.minError = e.TestError()
472      fmt.Printf("Best New Min Error:  %v\n", e)
473    }
474    if e.Size() > PS.maxSize {
475      PS.maxSize = e.Size()
476    }
477
478  }
479  return es
480}
481
482
483type PeelResult struct {
484  Es *probs.ExprReport
485
486  Nobestpush bool
487  BestNewMinErr bool
488 
489  Bestlen1 int
490  Bestlen2 int
491 
492  Coeff []float64
493  TestScore int
494 
495  Expre expr.Expr
496  ExpreRe *probs.ExprReport
497}
498
499//Problem
500func (PS *PgeSearch) SingleStep() []*PeelResult {
501  if os.Getenv("PGEDEBUG") == "1" {
502    fmt.Printf("Init PgeSearch: %#v \n", PS)
503    fmt.Printf("Init PgeProblem: %#v \n", PS.prob)
504  }
505
506
507  loop := 0
508  eval_cnt := 0
509  PR := PS.peelRes()
510 
511  es := make([]*probs.ExprReport, len(PR))
512 
513  for i, elem := range PR {
514    if(elem != nil) {
515      es[i] = elem.ExpreRe
516    //} else {
517    //  es[i] = new(probs.ExprReport)
518    }
519  }
520 
521  ex := PS.expandPeeled(es)
522
523  for cnt := range ex {
524    E := ex[cnt]
525    if E == nil {
526      continue
527    }
528    for _, e := range E {
529      if e == nil {
530        continue
531      }
532      if !PS.cnfg.treecfg.CheckExpr(e) {
533        continue
534      }
535      serial := make([]int, 0, 64)
536      serial = e.Serial(serial)
537      ins := PS.Trie.InsertSerial(serial)
538     
539      if !ins {
540        continue
541      }
542      PS.eval_in <- e
543      eval_cnt++
544    }
545  }
546  for i := 0; i < eval_cnt; i++ {
547    re := <-PS.eval_out
548    if math.IsNaN(re.TestError()) || math.IsInf(re.TestError(), 0) {
549      continue
550    }
551    if re.TestError() < PS.minError {
552      PS.minError = re.TestError()
553    }
554    doIns := true
555    for _, c := range re.Coeff() {
556      if math.Abs(c) < PS.cnfg.zeroEpsilon {
557        doIns = false
558        break
559      }
560    }
561    if doIns {
562      re.SetProcID(PS.id)
563      re.SetIterID(PS.iter)
564      re.SetUnitID(loop)
565      re.SetUniqID(PS.neqns)
566      loop++
567      PS.neqns++
568      PS.Queue.Push(re)
569    }
570  }
571  PS.Queue.Sort()
572
573  PS.reportExpr()
574  PS.iter++
575 
576  return PR
577}
578
579//Problem
580func (PS *PgeSearch) peelRes() []*PeelResult {
581  //es := make([]*probs.ExprReport, PS.cnfg.peelCnt)
582  PRes := make([]*PeelResult, PS.cnfg.peelCnt)
583 
584  for p := 0; p < PS.cnfg.peelCnt && PS.Queue.Len() > 0; p++ {
585    PR := new(PeelResult)
586     
587    PR.BestNewMinErr = false
588    PR.Nobestpush = false
589   
590    e := PS.Queue.Pop().(*probs.ExprReport)
591   
592    bPush := true
593    if len(e.Coeff()) == 1 && math.Abs(e.Coeff()[0]) < PS.cnfg.zeroEpsilon {
594      //fmt.Println("No Best Push")
595      PR.Nobestpush = true
596      p--
597      continue
598    }
599
600    if bPush {
601      //fmt.Printf("pop/push (%d,%d): %v\n", p, PS.Best.Len(), e.Expr())
602      PS.Best.Push(e)
603    }
604
605    //es[p] = e
606    PR.Bestlen1 = p
607    PR.Bestlen2 = PS.Best.Len()
608    PR.Expre = e.Expr()
609    PR.ExpreRe = e
610    PR.Es = e
611    PR.Coeff = e.Coeff()
612    PR.TestScore = e.TestScore()
613   
614   
615    if e.TestScore() > PS.maxScore {
616      PS.maxScore = e.TestScore()
617    }
618    if e.TestError() < PS.minError {
619      PS.minError = e.TestError()
620      PR.BestNewMinErr = true
621      //fmt.Printf("Best New Min Error:  %v\n", e)
622    }
623    if e.Size() > PS.maxSize {
624      PS.maxSize = e.Size()
625    }
626   
627    PRes[p] = PR
628  }
629 
630  return PRes
631}
632
633
634func (PS *PgeSearch) expandPeeled(es []*probs.ExprReport) [][]expr.Expr {
635  eqns := make([][]expr.Expr, PS.cnfg.peelCnt)
636  for p := 0; p < PS.cnfg.peelCnt; p++ {
637    if es[p] == nil {
638      continue
639    }
640    // fmt.Printf("expand(%d): %v\n", p, es[p].Expr())
641    if es[p].Expr().ExprType() != expr.ADD {
642      add := expr.NewAdd()
643      add.Insert(es[p].Expr())
644      add.CalcExprStats()
645      es[p].SetExpr(add)
646    }
647    eqns[p] = PS.Expand(es[p].Expr())
648    // fmt.Printf("Results:\n")
649    // for i, e := range eqns[p] {
650    //  fmt.Printf("%d,%d:  %v\n", p, i, e)
651    // }
652    // fmt.Println()
653  }
654  return eqns
655}
656
657func (PS *PgeSearch) reportExpr() {
658
659  cnt := PS.cnfg.pgeRptCount
660  PS.Best.Sort()
661
662  // repot best equations
663  rpt := make(probs.ExprReportArray, cnt)
664  if PS.Best.Len() < cnt {
665    cnt = PS.Best.Len()
666  }
667  copy(rpt, PS.Best.GetQueue()[:cnt])
668
669  errSum, errCnt := 0.0, 0
670  PS.eqnsLog.Println("\n\nReport", PS.iter)
671  for i, r := range rpt {
672    PS.eqnsLog.Printf("\n%d:  %v\n", i, r)
673    if r != nil && r.Expr() != nil {
674      errSum += r.TestError()
675      errCnt++
676    }
677  }
678
679  PS.mainLog.Printf("Iter: %d  %f  %f\n", PS.iter, errSum/float64(errCnt), PS.minError)
680
681  PS.ipreLog.Println(PS.iter, PS.neqns, PS.Trie.cnt, PS.Trie.vst)
682  PS.fitnessLog.Println(PS.iter, PS.neqns, PS.Trie.cnt, PS.Trie.vst, errSum/float64(errCnt), PS.minError)
683
684  PS.commup.Rpts <- &rpt
685
686}
687
688func (PS *PgeSearch) Clean() {
689  PS.errLogBuf.Flush()
690  PS.mainLogBuf.Flush()
691  PS.eqnsLogBuf.Flush()
692  PS.fitnessLogBuf.Flush()
693  PS.ipreLogBuf.Flush()
694}
695
696func (PS *PgeSearch) initLogs(logdir string) {
697  // open logs
698  PS.logDir = logdir + "pge/"
699  os.Mkdir(PS.logDir, os.ModePerm)
700  tmpF0, err5 := os.Create(PS.logDir + "pge:err.log")
701  if err5 != nil {
702    log.Fatal("couldn't create errs log")
703  }
704  PS.errLogBuf = bufio.NewWriter(tmpF0)
705  PS.errLogBuf.Flush()
706  PS.errLog = log.New(PS.errLogBuf, "", log.LstdFlags)
707
708  tmpF1, err1 := os.Create(PS.logDir + "pge:main.log")
709  if err1 != nil {
710    log.Fatal("couldn't create main log")
711  }
712  PS.mainLogBuf = bufio.NewWriter(tmpF1)
713  PS.mainLogBuf.Flush()
714  PS.mainLog = log.New(PS.mainLogBuf, "", log.LstdFlags)
715
716  tmpF2, err2 := os.Create(PS.logDir + "pge:eqns.log")
717  if err2 != nil {
718    log.Fatal("couldn't create eqns log")
719  }
720  PS.eqnsLogBuf = bufio.NewWriter(tmpF2)
721  PS.eqnsLogBuf.Flush()
722  PS.eqnsLog = log.New(PS.eqnsLogBuf, "", 0)
723
724  tmpF3, err3 := os.Create(PS.logDir + "pge:fitness.log")
725  if err3 != nil {
726    log.Fatal("couldn't create eqns log")
727  }
728  PS.fitnessLogBuf = bufio.NewWriter(tmpF3)
729  PS.fitnessLogBuf.Flush()
730  PS.fitnessLog = log.New(PS.fitnessLogBuf, "", log.Ltime|log.Lmicroseconds)
731
732  tmpF4, err4 := os.Create(PS.logDir + "pge:ipre.log")
733  if err4 != nil {
734    log.Fatal("couldn't create eqns log")
735  }
736  PS.ipreLogBuf = bufio.NewWriter(tmpF4)
737  PS.ipreLogBuf.Flush()
738  PS.ipreLog = log.New(PS.ipreLogBuf, "", log.Ltime|log.Lmicroseconds)
739}
740
741func (PS *PgeSearch) checkMessages() {
742
743  // check messages from superior
744  select {
745  case cmd, ok := <-PS.commup.Cmds:
746    if ok {
747      if cmd == -1 {
748        fmt.Println("PGE: stop sig recv'd")
749        PS.stop = true
750        return
751      }
752    }
753  default:
754    return
755  }
756}
757
758var c_input, c_ygiven []levmar.C_double
759
760func RegressExpr(E expr.Expr, P *probs.ExprProblem) (R *probs.ExprReport) {
761
762  guess := make([]float64, 0)
763  guess, eqn := E.ConvertToConstants(guess)
764
765  var coeff []float64
766  if len(guess) > 0 {
767
768    //fmt.Printf("IMREGEXP %v %v\n", P.SearchType, P.SearchVar)
769    //fmt.Printf("REGPTrain %T | %#v | %v\n", P.Train, P.Train, P.Train)
770    //fmt.Printf("eqn %T | %#v | %v\n", eqn, eqn, eqn)
771    //fmt.Printf("guess %T | %#v | %v\n", guess, guess, guess)
772
773    // Callback version
774   
775   
776    coeff = levmar.LevmarExpr(eqn, P.SearchVar, P.SearchType, guess, P.Train, P.Test)
777   
778
779    // Stack version
780    // x_dim := P.Train[0].NumDim()
781    // if c_input == nil {
782    //  ps := P.Train[0].NumPoints()
783    //  PS := len(P.Train) * ps
784    //  x_tot := PS * x_dim
785
786    //  c_input = make([]levmar.C_double, x_tot)
787    //  c_ygiven = make([]levmar.C_double, PS)
788
789    //  for i1, T := range P.Train {
790    //    for i2, p := range T.Points() {
791    //      i := i1*ps + i2
792    //      c_ygiven[i] = levmar.MakeCDouble(p.Depnd(P.SearchVar))
793    //      for i3, x_p := range p.Indeps() {
794    //        j := i1*ps*x_dim + i2*x_dim + i3
795    //        c_input[j] = levmar.MakeCDouble(x_p)
796    //      }
797    //    }
798    //  }
799    // }
800    // coeff = levmar.StackLevmarExpr(eqn, x_dim, guess, c_ygiven, c_input)
801
802    // serial := make([]int, 0)
803    // serial = eqn.StackSerial(serial)
804    // fmt.Printf("StackSerial: %v\n", serial)
805    // fmt.Printf("%v\n%v\n%v\n\n", eqn, coeff, steff)
806  }
807
808  R = new(probs.ExprReport)
809  R.SetExpr(eqn) /*.ConvertToConstantFs(coeff)*/
810
811  R.SetCoeff(coeff)
812  R.Expr().CalcExprStats()
813
814  // hitsL1, hitsL2, evalCnt, nanCnt, infCnt, l1_err, l2_err := scoreExpr(E, P, coeff)
815  _, _, _, trnNanCnt, _, trn_l1_err, _ := scoreExpr(E, P, P.Train, coeff)
816  _, _, tstEvalCnt, tstNanCnt, _, tst_l1_err, tst_l2_err := scoreExpr(E, P, P.Test, coeff)
817
818  R.SetTrainScore(trnNanCnt)
819  R.SetTrainError(trn_l1_err)
820
821  R.SetPredScore(tstNanCnt)
822  R.SetTestScore(tstEvalCnt)
823  R.SetTestError(tst_l1_err)
824  R.SetPredError(tst_l2_err)
825
826  return R
827}
828
829func scoreExpr(e expr.Expr, P *probs.ExprProblem, dataSets []*probs.PointSet, coeff []float64) (hitsL1, hitsL2, evalCnt, nanCnt, infCnt int, l1_err, l2_err float64) {
830  var l1_sum, l2_sum float64
831  for _, PS := range dataSets {
832    for _, p := range PS.Points() {
833      y := p.Depnd(P.SearchVar)
834      var out float64
835      if P.SearchType == probs.ExprBenchmark {
836        out = e.Eval(0, p.Indeps(), coeff, PS.SysVals())
837      } else if P.SearchType == probs.ExprDiffeq {
838        out = e.Eval(p.Indep(0), p.Indeps()[1:], coeff, PS.SysVals())
839      }
840
841      if math.IsNaN(out) {
842        nanCnt++
843        continue
844      } else if math.IsInf(out, 0) {
845        infCnt++
846        continue
847      } else {
848        evalCnt++
849      }
850
851      diff := out - y
852      l1_val := math.Abs(diff)
853      l2_val := diff * diff
854      l1_sum += l1_val
855      l2_sum += l2_val
856
857      if l1_val < P.HitRatio {
858        hitsL1++
859      }
860      if l2_val < P.HitRatio {
861        hitsL2++
862      }
863    }
864  }
865
866  if evalCnt == 0 {
867    l1_err = math.NaN()
868    l2_err = math.NaN()
869  } else {
870    fEvalCnt := float64(evalCnt + 1)
871    l1_err = l1_sum / fEvalCnt
872    l2_err = math.Sqrt(l2_sum / fEvalCnt)
873  }
874
875  return
876}
877
878func (PS *PgeSearch) CreateDS(maxGen int, pgeRptEpoch int, pgeRptCount int, pgeArchiveCap int, peelCnt int, evalrCount int, zeroEpsilon float64, initMethod string, growMethod string, sortType int) {
879  PS.id = 0 //maxGen
880 
881  var PC PgeConfig
882 
883  PC.pgeRptEpoch = pgeRptEpoch
884  PC.pgeRptCount = pgeRptCount
885  PC.pgeArchiveCap = pgeArchiveCap
886  PC.peelCnt = peelCnt
887  PC.evalrCount = evalrCount
888  PC.zeroEpsilon = zeroEpsilon
889  PC.initMethod = initMethod
890  PC.growMethod = growMethod
891
892  if sortType == 1 {
893    PC.sortType = probs.PESORT_PARETO_TRN_ERR
894  } else { 
895    PC.sortType = probs.PESORT_PARETO_TST_ERR
896  }
897
898  PC.maxGen = maxGen
899  PS.cnfg = PC
900   
901  PS.iter = 0
902}
903
904func (PS *PgeSearch) SetProb(ep *probs.ExprProblem) {
905  PS.prob = ep
906}
907func (PS *PgeSearch) GetIter() int {
908  return PS.iter
909}
Note: See TracBrowser for help on using the repository browser.