1 | package pge |
---|
2 | |
---|
3 | import ( |
---|
4 | "fmt" |
---|
5 | "io/ioutil" |
---|
6 | "math" |
---|
7 | "strconv" |
---|
8 | "strings" |
---|
9 | |
---|
10 | levmar "go-levmar" |
---|
11 | config "go-pge/config" |
---|
12 | probs "go-pge/problems" |
---|
13 | expr "go-symexpr" |
---|
14 | ) |
---|
15 | |
---|
16 | type PgeConfig struct { |
---|
17 | // search params |
---|
18 | maxGen int |
---|
19 | pgeRptEpoch int |
---|
20 | pgeRptCount int |
---|
21 | pgeArchiveCap int |
---|
22 | |
---|
23 | simprules expr.SimpRules |
---|
24 | treecfg *probs.TreeParams |
---|
25 | |
---|
26 | // PGE specific options |
---|
27 | peelCnt int |
---|
28 | sortType probs.SortType |
---|
29 | zeroEpsilon float64 |
---|
30 | |
---|
31 | initMethod string |
---|
32 | growMethod string |
---|
33 | |
---|
34 | evalrCount int |
---|
35 | } |
---|
36 | |
---|
37 | func pgeConfigParser(field, value string, config interface{}) (err error) { |
---|
38 | |
---|
39 | PC := config.(*PgeConfig) |
---|
40 | |
---|
41 | switch strings.ToUpper(field) { |
---|
42 | case "MAXGEN": |
---|
43 | PC.maxGen, err = strconv.Atoi(value) |
---|
44 | case "PGERPTEPOCH": |
---|
45 | PC.pgeRptEpoch, err = strconv.Atoi(value) |
---|
46 | case "PGERPTCOUNT": |
---|
47 | PC.pgeRptCount, err = strconv.Atoi(value) |
---|
48 | case "PGEARCHIVECAP": |
---|
49 | PC.pgeArchiveCap, err = strconv.Atoi(value) |
---|
50 | |
---|
51 | case "PEELCOUNT": |
---|
52 | PC.peelCnt, err = strconv.Atoi(value) |
---|
53 | |
---|
54 | case "EVALRCOUNT": |
---|
55 | PC.evalrCount, err = strconv.Atoi(value) |
---|
56 | |
---|
57 | case "SORTTYPE": |
---|
58 | switch strings.ToLower(value) { |
---|
59 | case "paretotrainerror": |
---|
60 | PC.sortType = probs.PESORT_PARETO_TRN_ERR |
---|
61 | case "paretotesterror": |
---|
62 | PC.sortType = probs.PESORT_PARETO_TST_ERR |
---|
63 | |
---|
64 | default: |
---|
65 | } |
---|
66 | |
---|
67 | case "ZEROEPSILON": |
---|
68 | PC.zeroEpsilon, err = strconv.ParseFloat(value, 64) |
---|
69 | |
---|
70 | default: |
---|
71 | // check augillary parsable structures [only TreeParams for now] |
---|
72 | if PC.treecfg == nil { |
---|
73 | PC.treecfg = new(probs.TreeParams) |
---|
74 | } |
---|
75 | found, ferr := probs.ParseTreeParams(field, value, PC.treecfg) |
---|
76 | if ferr != nil { |
---|
77 | return ferr |
---|
78 | } |
---|
79 | if !found { |
---|
80 | } |
---|
81 | |
---|
82 | } |
---|
83 | return |
---|
84 | } |
---|
85 | |
---|
86 | type PgeSearch struct { |
---|
87 | id int |
---|
88 | cnfg PgeConfig |
---|
89 | prob *probs.ExprProblem |
---|
90 | iter int |
---|
91 | stop bool |
---|
92 | |
---|
93 | // comm up |
---|
94 | commup *probs.ExprProblemComm |
---|
95 | |
---|
96 | // comm down |
---|
97 | |
---|
98 | // best exprs |
---|
99 | Best *probs.ReportQueue |
---|
100 | |
---|
101 | // training data in C format |
---|
102 | c_input []levmar.C_double |
---|
103 | c_ygiven []levmar.C_double |
---|
104 | |
---|
105 | // logs |
---|
106 | |
---|
107 | // equations visited |
---|
108 | Trie *IpreNode |
---|
109 | Queue *probs.ReportQueue |
---|
110 | |
---|
111 | // eval channels |
---|
112 | eval_in chan expr.Expr |
---|
113 | eval_out chan *probs.ExprReport |
---|
114 | |
---|
115 | // genStuff |
---|
116 | GenRoots []expr.Expr |
---|
117 | GenLeafs []expr.Expr |
---|
118 | GenNodes []expr.Expr |
---|
119 | GenNonTrig []expr.Expr |
---|
120 | |
---|
121 | // FFXish stuff |
---|
122 | ffxBases []expr.Expr |
---|
123 | |
---|
124 | // statistics |
---|
125 | neqns int |
---|
126 | ipreCnt int |
---|
127 | maxSize int |
---|
128 | maxScore int |
---|
129 | minError float64 |
---|
130 | } |
---|
131 | |
---|
132 | func (PS *PgeSearch) GetMaxIter() int { |
---|
133 | return PS.cnfg.maxGen |
---|
134 | } |
---|
135 | func (PS *PgeSearch) GetPeelCount() int { |
---|
136 | return PS.cnfg.peelCnt |
---|
137 | } |
---|
138 | func (PS *PgeSearch) SetMaxIter(iter int) { |
---|
139 | PS.cnfg.maxGen = iter |
---|
140 | } |
---|
141 | func (PS *PgeSearch) SetPeelCount(cnt int) { |
---|
142 | PS.cnfg.peelCnt = cnt |
---|
143 | } |
---|
144 | func (PS *PgeSearch) SetInitMethod(init string) { |
---|
145 | PS.cnfg.initMethod = init |
---|
146 | } |
---|
147 | func (PS *PgeSearch) SetGrowMethod(grow string) { |
---|
148 | PS.cnfg.growMethod = grow |
---|
149 | } |
---|
150 | func (PS *PgeSearch) SetEvalrCount(cnt int) { |
---|
151 | PS.cnfg.evalrCount = cnt |
---|
152 | } |
---|
153 | |
---|
154 | func (PS *PgeSearch) ParseConfig(filename string) { |
---|
155 | fmt.Printf("Parsing PGE Config: %s\n", filename) |
---|
156 | data, err := ioutil.ReadFile(filename) |
---|
157 | if err != nil { |
---|
158 | } |
---|
159 | err = config.ParseConfig(data, pgeConfigParser, &PS.cnfg) |
---|
160 | if err != nil { |
---|
161 | } |
---|
162 | } |
---|
163 | |
---|
164 | func (PS *PgeSearch) SetSort(typen int) { |
---|
165 | if typen == 1 { |
---|
166 | PS.Best.SetSort(probs.GPSORT_PARETO_TST_ERR) |
---|
167 | PS.Best.SetSort(probs.GPSORT_PARETO_TST_ERR) |
---|
168 | } else { |
---|
169 | PS.Best.SetSort(probs.PESORT_PARETO_TST_ERR) |
---|
170 | PS.Best.SetSort(probs.PESORT_PARETO_TST_ERR) |
---|
171 | } |
---|
172 | } |
---|
173 | |
---|
174 | func (PS *PgeSearch) Init(done chan int, prob *probs.ExprProblem, logdir string, input interface{}) { |
---|
175 | |
---|
176 | fmt.Printf("Init'n PGE\n") |
---|
177 | // setup data |
---|
178 | |
---|
179 | // open logs |
---|
180 | //PS.initLogs(logdir) |
---|
181 | |
---|
182 | PS.stop = false |
---|
183 | |
---|
184 | // copy in common config options |
---|
185 | PS.prob = prob |
---|
186 | if PS.cnfg.treecfg == nil { |
---|
187 | PS.cnfg.treecfg = PS.prob.TreeCfg.Clone() |
---|
188 | } |
---|
189 | srules := expr.DefaultRules() |
---|
190 | srules.ConvertConsts = true |
---|
191 | PS.cnfg.simprules = srules |
---|
192 | |
---|
193 | fmt.Println("Roots: ", PS.cnfg.treecfg.RootsS) |
---|
194 | fmt.Println("Nodes: ", PS.cnfg.treecfg.NodesS) |
---|
195 | fmt.Println("Leafs: ", PS.cnfg.treecfg.LeafsS) |
---|
196 | fmt.Println("NonTrig: ", PS.cnfg.treecfg.NonTrigS) |
---|
197 | |
---|
198 | PS.GenRoots = make([]expr.Expr, len(PS.cnfg.treecfg.Roots)) |
---|
199 | for i := 0; i < len(PS.GenRoots); i++ { |
---|
200 | PS.GenRoots[i] = PS.cnfg.treecfg.Roots[i].Clone() |
---|
201 | } |
---|
202 | PS.GenNodes = make([]expr.Expr, len(PS.cnfg.treecfg.Nodes)) |
---|
203 | for i := 0; i < len(PS.GenNodes); i++ { |
---|
204 | PS.GenNodes[i] = PS.cnfg.treecfg.Nodes[i].Clone() |
---|
205 | } |
---|
206 | PS.GenNonTrig = make([]expr.Expr, len(PS.cnfg.treecfg.NonTrig)) |
---|
207 | for i := 0; i < len(PS.GenNonTrig); i++ { |
---|
208 | PS.GenNonTrig[i] = PS.cnfg.treecfg.NonTrig[i].Clone() |
---|
209 | } |
---|
210 | |
---|
211 | PS.GenLeafs = make([]expr.Expr, 0) |
---|
212 | for _, t := range PS.cnfg.treecfg.LeafsT { |
---|
213 | switch t { |
---|
214 | case expr.TIME: |
---|
215 | PS.GenLeafs = append(PS.GenLeafs, expr.NewTime()) |
---|
216 | |
---|
217 | case expr.VAR: |
---|
218 | fmt.Println("Use Vars: ", PS.cnfg.treecfg.UsableVars) |
---|
219 | for _, i := range PS.cnfg.treecfg.UsableVars { |
---|
220 | PS.GenLeafs = append(PS.GenLeafs, expr.NewVar(i)) |
---|
221 | } |
---|
222 | |
---|
223 | case expr.SYSTEM: |
---|
224 | for i := 0; i < PS.prob.Train[0].NumSys(); i++ { |
---|
225 | PS.GenLeafs = append(PS.GenLeafs, expr.NewSystem(i)) |
---|
226 | } |
---|
227 | |
---|
228 | } |
---|
229 | } |
---|
230 | |
---|
231 | /*** FIX ME |
---|
232 | PS.GenLeafs = make([]expr.Expr, len(PS.cnfg.treecfg.Leafs)) |
---|
233 | for i := 0; i < len(PS.GenLeafs); i++ { |
---|
234 | PS.GenLeafs[i] = PS.cnfg.treecfg.Leafs[i].Clone() |
---|
235 | } |
---|
236 | ***/ |
---|
237 | |
---|
238 | //fmt.Println("Roots: ", PS.GenRoots) |
---|
239 | //fmt.Println("Nodes: ", PS.GenNodes) |
---|
240 | //fmt.Println("Leafs: ", PS.GenLeafs) |
---|
241 | //fmt.Println("NonTrig: ", PS.GenNonTrig) |
---|
242 | |
---|
243 | // setup communication struct |
---|
244 | PS.commup = input.(*probs.ExprProblemComm) |
---|
245 | |
---|
246 | // initialize bbq |
---|
247 | PS.Trie = new(IpreNode) |
---|
248 | PS.Trie.val = -1 |
---|
249 | PS.Trie.next = make(map[int]*IpreNode) |
---|
250 | |
---|
251 | PS.Best = probs.NewReportQueue() |
---|
252 | PS.Best.SetSort(probs.GPSORT_PARETO_TST_ERR) |
---|
253 | |
---|
254 | PS.Queue = PS.GenInitExpr() |
---|
255 | PS.Queue.SetSort(probs.PESORT_PARETO_TST_ERR) |
---|
256 | |
---|
257 | PS.neqns = PS.Queue.Len() |
---|
258 | |
---|
259 | PS.minError = math.Inf(1) |
---|
260 | |
---|
261 | PS.eval_in = make(chan expr.Expr, 4048) |
---|
262 | PS.eval_out = make(chan *probs.ExprReport, 4048) |
---|
263 | |
---|
264 | for i := 0; i < PS.cnfg.evalrCount; i++ { |
---|
265 | go PS.Evaluate() |
---|
266 | } |
---|
267 | } |
---|
268 | |
---|
269 | func (PS *PgeSearch) Stop() { |
---|
270 | PS.stop = true |
---|
271 | } |
---|
272 | |
---|
273 | func (PS *PgeSearch) Evaluate() { |
---|
274 | |
---|
275 | for !PS.stop { |
---|
276 | e := <-PS.eval_in |
---|
277 | if e == nil { |
---|
278 | continue |
---|
279 | } |
---|
280 | //re := |
---|
281 | //fmt.Printf("reg: %v\n", re) |
---|
282 | PS.eval_out <- RegressExpr(e, PS.prob) |
---|
283 | } |
---|
284 | |
---|
285 | } |
---|
286 | |
---|
287 | func (PS *PgeSearch) Run() { |
---|
288 | fmt.Printf("Running PGE\n") |
---|
289 | |
---|
290 | PS.Loop() |
---|
291 | |
---|
292 | fmt.Println("PGE exitting") |
---|
293 | |
---|
294 | PS.Clean() |
---|
295 | PS.commup.Cmds <- -1 |
---|
296 | } |
---|
297 | |
---|
298 | func (PS *PgeSearch) Loop() { |
---|
299 | |
---|
300 | PS.checkMessages() |
---|
301 | for !PS.stop { |
---|
302 | |
---|
303 | fmt.Println("in: PS.step() ", PS.iter) |
---|
304 | PS.Step() |
---|
305 | |
---|
306 | // if PS.iter%PS.cnfg.pgeRptEpoch == 0 { |
---|
307 | PS.ReportExpr(false) |
---|
308 | // } |
---|
309 | |
---|
310 | // report current iteration |
---|
311 | PS.commup.Gen <- [2]int{PS.id, PS.iter} |
---|
312 | PS.iter++ |
---|
313 | |
---|
314 | PS.Clean() |
---|
315 | |
---|
316 | PS.checkMessages() |
---|
317 | |
---|
318 | } |
---|
319 | |
---|
320 | // done expanding, pull the rest of the regressed solutions from the queue |
---|
321 | p := 0 |
---|
322 | for PS.Queue.Len() > 0 { |
---|
323 | e := PS.Queue.Pop().(*probs.ExprReport) |
---|
324 | |
---|
325 | bPush := true |
---|
326 | if len(e.Coeff()) == 1 && math.Abs(e.Coeff()[0]) < PS.cnfg.zeroEpsilon { |
---|
327 | // fmt.Println("No Best Push") |
---|
328 | bPush = false |
---|
329 | } |
---|
330 | |
---|
331 | if bPush { |
---|
332 | // fmt.Printf("pop/push(%d,%d): %v\n", p, PS.Best.Len(), e.Expr()) |
---|
333 | PS.Best.Push(e) |
---|
334 | p++ |
---|
335 | } |
---|
336 | |
---|
337 | if e.TestScore() > PS.maxScore { |
---|
338 | PS.maxScore = e.TestScore() |
---|
339 | } |
---|
340 | if e.TestError() < PS.minError { |
---|
341 | PS.minError = e.TestError() |
---|
342 | fmt.Printf("EXITING New Min Error: %v\n", e) |
---|
343 | } |
---|
344 | if e.Size() > PS.maxSize { |
---|
345 | PS.maxSize = e.Size() |
---|
346 | } |
---|
347 | } |
---|
348 | |
---|
349 | fmt.Println("PGE sending last report") |
---|
350 | PS.ReportExpr(false) |
---|
351 | |
---|
352 | } |
---|
353 | |
---|
354 | type PeelResult struct { |
---|
355 | Es *probs.ExprReport |
---|
356 | |
---|
357 | Nobestpush bool |
---|
358 | BestNewMinErr bool |
---|
359 | |
---|
360 | Bestlen1 int |
---|
361 | Bestlen2 int |
---|
362 | |
---|
363 | Coeff []float64 |
---|
364 | TestScore int |
---|
365 | |
---|
366 | Expre expr.Expr |
---|
367 | ExpreRe *probs.ExprReport |
---|
368 | } |
---|
369 | |
---|
370 | func (PS *PgeSearch) Step() int { |
---|
371 | |
---|
372 | loop := 0 |
---|
373 | eval_cnt := 0 // for channeled eval |
---|
374 | |
---|
375 | es := PS.peel() |
---|
376 | |
---|
377 | ex := PS.expandPeeled(es) |
---|
378 | |
---|
379 | cnt_ins := 0 |
---|
380 | |
---|
381 | inserts := make(probs.ExprReportArray, 0) |
---|
382 | for cnt := range ex { |
---|
383 | E := ex[cnt] |
---|
384 | |
---|
385 | if E == nil { |
---|
386 | continue |
---|
387 | } |
---|
388 | |
---|
389 | for _, e := range E { |
---|
390 | if e == nil { |
---|
391 | continue |
---|
392 | } |
---|
393 | if !PS.cnfg.treecfg.CheckExpr(e) { |
---|
394 | continue |
---|
395 | } |
---|
396 | |
---|
397 | // check ipre_trie |
---|
398 | serial := make([]int, 0, 64) |
---|
399 | serial = e.Serial(serial) |
---|
400 | ins := PS.Trie.InsertSerial(serial) |
---|
401 | if !ins { |
---|
402 | continue |
---|
403 | } |
---|
404 | |
---|
405 | // for serial eval |
---|
406 | //re := RegressExpr(e, PS.prob) //needed for TestScore calc via RegressExpr |
---|
407 | //inserts = append(inserts, re) |
---|
408 | |
---|
409 | // start channeled eval |
---|
410 | PS.eval_in <- e |
---|
411 | eval_cnt++ |
---|
412 | } |
---|
413 | } |
---|
414 | |
---|
415 | for i := 0; i < eval_cnt; i++ { |
---|
416 | re := <-PS.eval_out |
---|
417 | // end channeled eval |
---|
418 | |
---|
419 | // check for NaN/Inf in re.error and if so, skip |
---|
420 | if math.IsNaN(re.TestError()) || math.IsInf(re.TestError(), 0) { |
---|
421 | // fmt.Printf("Bad Error\n%v\n", re) |
---|
422 | continue |
---|
423 | } |
---|
424 | |
---|
425 | if re.TestError() < PS.minError { |
---|
426 | PS.minError = re.TestError() |
---|
427 | } |
---|
428 | |
---|
429 | // check for coeff == 0 |
---|
430 | doIns := true |
---|
431 | for _, c := range re.Coeff() { |
---|
432 | // i > 0 for free coeff |
---|
433 | if math.Abs(c) < PS.cnfg.zeroEpsilon { |
---|
434 | doIns = false |
---|
435 | break |
---|
436 | } |
---|
437 | } |
---|
438 | |
---|
439 | //fmt.Printf("StepQueue.Push(): %v\n", re) |
---|
440 | //fmt.Printf("StepQueue.Push(): %v\n", re.Expr()) |
---|
441 | |
---|
442 | if doIns { |
---|
443 | re.SetProcID(PS.id) |
---|
444 | re.SetIterID(PS.iter) |
---|
445 | re.SetUnitID(loop) |
---|
446 | re.SetUniqID(PS.neqns) |
---|
447 | loop++ |
---|
448 | PS.neqns++ |
---|
449 | // fmt.Printf("Queue.Push(): %v\n%v\n\n", re.Expr(), serial) |
---|
450 | // fmt.Printf("Queue.Push(): %v\n", re) |
---|
451 | // fmt.Printf("Queue.Push(): %v\n", re.Expr()) |
---|
452 | cnt_ins++ |
---|
453 | PS.Queue.Push(re) |
---|
454 | //fmt.Printf("Testscore: %v\n", re.TestScore()) |
---|
455 | //PS.commup.Rpts <- &re //sort missing! > 3 ergs |
---|
456 | } |
---|
457 | } |
---|
458 | |
---|
459 | PS.Queue.Sort() //3 besten werden beim naechsten peel ausgegeben |
---|
460 | |
---|
461 | for p := 0; p < PS.cnfg.peelCnt && PS.Queue.Len() > 0; p++ { |
---|
462 | val := PS.Queue.Pop().(*probs.ExprReport) |
---|
463 | inserts = append(inserts, val) |
---|
464 | } |
---|
465 | |
---|
466 | for _, e := range inserts { |
---|
467 | PS.Queue.Push(e) |
---|
468 | } |
---|
469 | |
---|
470 | PS.Queue.Sort() |
---|
471 | |
---|
472 | PS.commup.Res <- &inserts |
---|
473 | |
---|
474 | PS.ReportExpr(false) |
---|
475 | PS.iter++ |
---|
476 | |
---|
477 | return len(inserts) |
---|
478 | } |
---|
479 | |
---|
480 | func (PS *PgeSearch) peel() []*probs.ExprReport { |
---|
481 | es := make([]*probs.ExprReport, PS.cnfg.peelCnt) |
---|
482 | |
---|
483 | rpt := make(probs.ExprReportArray, 0) |
---|
484 | |
---|
485 | for p := 0; p < PS.cnfg.peelCnt && PS.Queue.Len() > 0; p++ { |
---|
486 | |
---|
487 | e := PS.Queue.Pop().(*probs.ExprReport) |
---|
488 | |
---|
489 | bPush := true |
---|
490 | if len(e.Coeff()) == 1 && math.Abs(e.Coeff()[0]) < PS.cnfg.zeroEpsilon { |
---|
491 | fmt.Println("No Best Push") |
---|
492 | p-- |
---|
493 | continue |
---|
494 | } |
---|
495 | |
---|
496 | if bPush { |
---|
497 | fmt.Printf("BEST PUSH/push(%d,%d): %v\n", p, PS.Best.Len(), e.Expr()) |
---|
498 | PS.Best.Push(e) |
---|
499 | rpt = append(rpt, e) |
---|
500 | } |
---|
501 | |
---|
502 | es[p] = e |
---|
503 | |
---|
504 | if e.TestScore() > PS.maxScore { |
---|
505 | fmt.Printf("Testscore: %v\n", e.TestScore()) |
---|
506 | PS.maxScore = e.TestScore() |
---|
507 | } |
---|
508 | if e.TestError() < PS.minError { |
---|
509 | PS.minError = e.TestError() |
---|
510 | fmt.Printf("Best New Min Error: %v\n", e) |
---|
511 | } |
---|
512 | if e.Size() > PS.maxSize { |
---|
513 | PS.maxSize = e.Size() |
---|
514 | } |
---|
515 | |
---|
516 | } |
---|
517 | |
---|
518 | //fmt.Printf("sand %d best pushes in peel\n", len(rpt)) |
---|
519 | //PS.commup.Res <- &rpt |
---|
520 | |
---|
521 | _ = rpt |
---|
522 | |
---|
523 | return es |
---|
524 | } |
---|
525 | |
---|
526 | func (PS *PgeSearch) expandPeeled(es []*probs.ExprReport) [][]expr.Expr { |
---|
527 | eqns := make([][]expr.Expr, PS.cnfg.peelCnt) |
---|
528 | for p := 0; p < PS.cnfg.peelCnt; p++ { |
---|
529 | if es[p] == nil { |
---|
530 | continue |
---|
531 | } |
---|
532 | // fmt.Printf("expand(%d): %v\n", p, es[p].Expr()) |
---|
533 | if es[p].Expr().ExprType() != expr.ADD { |
---|
534 | add := expr.NewAdd() |
---|
535 | add.Insert(es[p].Expr()) |
---|
536 | add.CalcExprStats() |
---|
537 | es[p].SetExpr(add) |
---|
538 | } |
---|
539 | eqns[p] = PS.Expand(es[p].Expr()) |
---|
540 | // fmt.Printf("Results:\n") |
---|
541 | // for i, e := range eqns[p] { |
---|
542 | // fmt.Printf("%d,%d: %v\n", p, i, e) |
---|
543 | // } |
---|
544 | // fmt.Println() |
---|
545 | } |
---|
546 | fmt.Println("\n") |
---|
547 | return eqns |
---|
548 | } |
---|
549 | |
---|
550 | func (PS *PgeSearch) ReportExpr(writeChannel bool) { |
---|
551 | |
---|
552 | cnt := PS.cnfg.pgeRptCount |
---|
553 | PS.Best.Sort() |
---|
554 | |
---|
555 | // report best equations |
---|
556 | rpt := make(probs.ExprReportArray, cnt) |
---|
557 | if PS.Best.Len() < cnt { |
---|
558 | cnt = PS.Best.Len() |
---|
559 | } |
---|
560 | copy(rpt, PS.Best.GetQueue()[:cnt]) |
---|
561 | |
---|
562 | errSum, errCnt := 0.0, 0 |
---|
563 | for _, r := range rpt { |
---|
564 | if r != nil && r.Expr() != nil { |
---|
565 | errSum += r.TestError() |
---|
566 | errCnt++ |
---|
567 | } |
---|
568 | } |
---|
569 | |
---|
570 | if writeChannel { |
---|
571 | PS.commup.Rpts <- &rpt |
---|
572 | } |
---|
573 | } |
---|
574 | |
---|
575 | func (PS *PgeSearch) FirstPeel() { |
---|
576 | PS.peel() |
---|
577 | } |
---|
578 | |
---|
579 | func (PS *PgeSearch) Clean() { |
---|
580 | |
---|
581 | } |
---|
582 | |
---|
583 | func (PS *PgeSearch) initLogs(logdir string) { |
---|
584 | |
---|
585 | } |
---|
586 | |
---|
587 | func (PS *PgeSearch) checkMessages() { |
---|
588 | |
---|
589 | // check messages from superior |
---|
590 | select { |
---|
591 | case cmd, ok := <-PS.commup.Cmds: |
---|
592 | if ok { |
---|
593 | if cmd == -1 { |
---|
594 | fmt.Println("PGE: stop sig recv'd") |
---|
595 | PS.stop = true |
---|
596 | return |
---|
597 | } |
---|
598 | } |
---|
599 | default: |
---|
600 | return |
---|
601 | } |
---|
602 | } |
---|
603 | |
---|
604 | var c_input, c_ygiven []levmar.C_double |
---|
605 | |
---|
606 | func RegressExpr(E expr.Expr, P *probs.ExprProblem) (R *probs.ExprReport) { |
---|
607 | |
---|
608 | guess := make([]float64, 0) |
---|
609 | guess, eqn := E.ConvertToConstants(guess) |
---|
610 | |
---|
611 | var coeff []float64 |
---|
612 | if len(guess) > 0 { |
---|
613 | |
---|
614 | // fmt.Printf("x_dims: %d %d\n", x_dim, x_dim2) |
---|
615 | |
---|
616 | // Callback version |
---|
617 | coeff = levmar.LevmarExpr(eqn, P.SearchVar, P.SearchType, guess, P.Train, P.Test) |
---|
618 | |
---|
619 | // Stack version |
---|
620 | // x_dim := P.Train[0].NumDim() |
---|
621 | // if c_input == nil { |
---|
622 | // ps := P.Train[0].NumPoints() |
---|
623 | // PS := len(P.Train) * ps |
---|
624 | // x_tot := PS * x_dim |
---|
625 | |
---|
626 | // c_input = make([]levmar.C_double, x_tot) |
---|
627 | // c_ygiven = make([]levmar.C_double, PS) |
---|
628 | |
---|
629 | // for i1, T := range P.Train { |
---|
630 | // for i2, p := range T.Points() { |
---|
631 | // i := i1*ps + i2 |
---|
632 | // c_ygiven[i] = levmar.MakeCDouble(p.Depnd(P.SearchVar)) |
---|
633 | // for i3, x_p := range p.Indeps() { |
---|
634 | // j := i1*ps*x_dim + i2*x_dim + i3 |
---|
635 | // c_input[j] = levmar.MakeCDouble(x_p) |
---|
636 | // } |
---|
637 | // } |
---|
638 | // } |
---|
639 | // } |
---|
640 | // coeff = levmar.StackLevmarExpr(eqn, x_dim, guess, c_ygiven, c_input) |
---|
641 | |
---|
642 | // serial := make([]int, 0) |
---|
643 | // serial = eqn.StackSerial(serial) |
---|
644 | // fmt.Printf("StackSerial: %v\n", serial) |
---|
645 | // fmt.Printf("%v\n%v\n%v\n\n", eqn, coeff, steff) |
---|
646 | } |
---|
647 | |
---|
648 | R = new(probs.ExprReport) |
---|
649 | R.SetExpr(eqn) /*.ConvertToConstantFs(coeff)*/ |
---|
650 | R.SetCoeff(coeff) |
---|
651 | R.Expr().CalcExprStats() |
---|
652 | |
---|
653 | // hitsL1, hitsL2, evalCnt, nanCnt, infCnt, l1_err, l2_err := scoreExpr(E, P, coeff) |
---|
654 | _, _, _, trnNanCnt, _, trn_l1_err, _ := scoreExpr(E, P, P.Train, coeff) |
---|
655 | _, _, tstEvalCnt, tstNanCnt, _, tst_l1_err, tst_l2_err := scoreExpr(E, P, P.Test, coeff) |
---|
656 | |
---|
657 | R.SetTrainScore(trnNanCnt) |
---|
658 | R.SetTrainError(trn_l1_err) |
---|
659 | |
---|
660 | R.SetPredScore(tstNanCnt) |
---|
661 | R.SetTestScore(tstEvalCnt) |
---|
662 | R.SetTestError(tst_l1_err) |
---|
663 | R.SetPredError(tst_l2_err) |
---|
664 | |
---|
665 | return R |
---|
666 | } |
---|
667 | |
---|
668 | func scoreExpr(e expr.Expr, P *probs.ExprProblem, dataSets []*probs.PointSet, coeff []float64) (hitsL1, hitsL2, evalCnt, nanCnt, infCnt int, l1_err, l2_err float64) { |
---|
669 | var l1_sum, l2_sum float64 |
---|
670 | for _, PS := range dataSets { |
---|
671 | for _, p := range PS.Points() { |
---|
672 | y := p.Depnd(P.SearchVar) |
---|
673 | var out float64 |
---|
674 | if P.SearchType == probs.ExprBenchmark { |
---|
675 | out = e.Eval(0, p.Indeps(), coeff, PS.SysVals()) |
---|
676 | } else if P.SearchType == probs.ExprDiffeq { |
---|
677 | out = e.Eval(p.Indep(0), p.Indeps()[1:], coeff, PS.SysVals()) |
---|
678 | } |
---|
679 | |
---|
680 | if math.IsNaN(out) { |
---|
681 | nanCnt++ |
---|
682 | continue |
---|
683 | } else if math.IsInf(out, 0) { |
---|
684 | infCnt++ |
---|
685 | continue |
---|
686 | } else { |
---|
687 | evalCnt++ |
---|
688 | } |
---|
689 | |
---|
690 | diff := out - y |
---|
691 | l1_val := math.Abs(diff) |
---|
692 | l2_val := diff * diff |
---|
693 | l1_sum += l1_val |
---|
694 | l2_sum += l2_val |
---|
695 | |
---|
696 | if l1_val < P.HitRatio { |
---|
697 | hitsL1++ |
---|
698 | } |
---|
699 | if l2_val < P.HitRatio { |
---|
700 | hitsL2++ |
---|
701 | } |
---|
702 | } |
---|
703 | } |
---|
704 | |
---|
705 | if evalCnt == 0 { |
---|
706 | l1_err = math.NaN() |
---|
707 | l2_err = math.NaN() |
---|
708 | } else { |
---|
709 | fEvalCnt := float64(evalCnt + 1) |
---|
710 | l1_err = l1_sum / fEvalCnt |
---|
711 | l2_err = math.Sqrt(l2_sum / fEvalCnt) |
---|
712 | } |
---|
713 | |
---|
714 | return |
---|
715 | } |
---|
716 | |
---|
717 | func (PS *PgeSearch) CreateDS(maxGen int, pgeRptEpoch int, pgeRptCount int, pgeArchiveCap int, peelCnt int, evalrCount int, zeroEpsilon float64, initMethod string, growMethod string, sortType int) { |
---|
718 | PS.id = 0 |
---|
719 | |
---|
720 | var PC PgeConfig |
---|
721 | |
---|
722 | PC.pgeRptEpoch = pgeRptEpoch |
---|
723 | PC.pgeRptCount = pgeRptCount |
---|
724 | PC.pgeArchiveCap = pgeArchiveCap |
---|
725 | PC.peelCnt = peelCnt |
---|
726 | PC.evalrCount = evalrCount |
---|
727 | PC.zeroEpsilon = zeroEpsilon |
---|
728 | PC.initMethod = initMethod |
---|
729 | PC.growMethod = growMethod |
---|
730 | |
---|
731 | if sortType == 1 { |
---|
732 | PC.sortType = probs.PESORT_PARETO_TRN_ERR |
---|
733 | } else { |
---|
734 | PC.sortType = probs.PESORT_PARETO_TST_ERR |
---|
735 | } |
---|
736 | |
---|
737 | PC.maxGen = maxGen |
---|
738 | PS.cnfg = PC |
---|
739 | PS.stop = false |
---|
740 | PS.iter = 0 |
---|
741 | } |
---|
742 | |
---|
743 | func (PS *PgeSearch) SetProb(ep *probs.ExprProblem) { |
---|
744 | PS.prob = ep |
---|
745 | } |
---|
746 | func (PS *PgeSearch) GetIter() int { |
---|
747 | return PS.iter |
---|
748 | } |
---|