1 | package main
|
---|
2 |
|
---|
3 | import (
|
---|
4 | "bufio"
|
---|
5 | "fmt"
|
---|
6 | "io/ioutil"
|
---|
7 | "log"
|
---|
8 | "os"
|
---|
9 | "sort"
|
---|
10 | "strings"
|
---|
11 | "time"
|
---|
12 |
|
---|
13 | config "github.com/verdverm/go-pge/config"
|
---|
14 | pge "github.com/verdverm/go-pge/pge"
|
---|
15 | probs "github.com/verdverm/go-pge/problems"
|
---|
16 | )
|
---|
17 |
|
---|
18 | // defines the interface to a search type [GP,PE]
|
---|
19 | type Search interface {
|
---|
20 |
|
---|
21 | // parse a params file
|
---|
22 | ParseConfig(filename string)
|
---|
23 |
|
---|
24 | // initialize the search, sending signal on chan when done
|
---|
25 | // the input will be something for the search to connect to
|
---|
26 | // in order to provide updates, be monitored, and receive control signals
|
---|
27 | Init(done chan int, prob *probs.ExprProblem, logdir string, input interface{})
|
---|
28 |
|
---|
29 | // start the actual search procedure (Init is required before a new call to Run)
|
---|
30 | Run()
|
---|
31 |
|
---|
32 | // clean up internal structures (Init is required before a new call to Run)
|
---|
33 | Clean()
|
---|
34 | }
|
---|
35 |
|
---|
36 | // parameters to a main search, which sets up the global system
|
---|
37 | // and instructs in where to find the sub-searches
|
---|
38 | type mainConfig struct {
|
---|
39 | dataDir string
|
---|
40 | cfgDir string
|
---|
41 | logDir string
|
---|
42 |
|
---|
43 | probCfg string
|
---|
44 | srchCfg []string
|
---|
45 | }
|
---|
46 |
|
---|
47 | func mainConfigParser(field, value string, config interface{}) (err error) {
|
---|
48 |
|
---|
49 | DC := config.(*mainConfig)
|
---|
50 |
|
---|
51 | switch strings.ToUpper(field) {
|
---|
52 | case "CONFIGDIR":
|
---|
53 | DC.cfgDir = value
|
---|
54 | case "DATADIR":
|
---|
55 | DC.dataDir = value
|
---|
56 | case "LOGDIR":
|
---|
57 | DC.logDir = value
|
---|
58 |
|
---|
59 | case "PROBLEMCFG":
|
---|
60 | DC.probCfg = value
|
---|
61 | case "SEARCHCFG":
|
---|
62 | DC.srchCfg = strings.Fields(value)
|
---|
63 | default:
|
---|
64 | log.Printf("Main Not Implemented %s, %s\n\n", field, value)
|
---|
65 |
|
---|
66 | }
|
---|
67 | return err
|
---|
68 | }
|
---|
69 |
|
---|
70 | // defines the global level search, which may use different sub-searches one or more times
|
---|
71 | type MainSearch struct {
|
---|
72 | cnfg mainConfig
|
---|
73 |
|
---|
74 | // problem and best results
|
---|
75 | prob *probs.ExprProblem
|
---|
76 | eqns probs.ExprReportArray
|
---|
77 | per_eqns []*probs.ExprReportArray
|
---|
78 |
|
---|
79 | // sub-searches and comm
|
---|
80 | srch []Search
|
---|
81 | comm []*probs.ExprProblemComm
|
---|
82 | iter []int
|
---|
83 |
|
---|
84 | // logs
|
---|
85 | logDir string
|
---|
86 | mainLog *log.Logger
|
---|
87 | mainLogBuf *bufio.Writer
|
---|
88 | eqnsLog *log.Logger
|
---|
89 | eqnsLogBuf *bufio.Writer
|
---|
90 | errLog *log.Logger
|
---|
91 | errLogBuf *bufio.Writer
|
---|
92 | }
|
---|
93 |
|
---|
94 | func (DS *MainSearch) ParseConfig(filename string) {
|
---|
95 | fmt.Printf("Parsing Main Config: %s\n", filename)
|
---|
96 | data, err := ioutil.ReadFile(filename)
|
---|
97 | if err != nil {
|
---|
98 | log.Fatal(err)
|
---|
99 | }
|
---|
100 |
|
---|
101 | err = config.ParseConfig(data, mainConfigParser, &DS.cnfg)
|
---|
102 | if err != nil {
|
---|
103 | log.Fatal(err)
|
---|
104 | }
|
---|
105 | fmt.Printf("%v\n", DS.cnfg)
|
---|
106 | }
|
---|
107 |
|
---|
108 | func (DS *MainSearch) Init(done chan int, input interface{}) {
|
---|
109 | fmt.Printf("Init'n PGE1\n----------\n")
|
---|
110 |
|
---|
111 | DC := DS.cnfg
|
---|
112 |
|
---|
113 | // read and setup problem
|
---|
114 | eprob := new(probs.ExprProblem)
|
---|
115 |
|
---|
116 | fmt.Printf("Parsing Problem Config: %s\n", DC.probCfg)
|
---|
117 | data, err := ioutil.ReadFile(DC.cfgDir + DC.probCfg)
|
---|
118 | if err != nil {
|
---|
119 | log.Fatal(err)
|
---|
120 | }
|
---|
121 | err = config.ParseConfig(data, probs.ProbConfigParser, eprob)
|
---|
122 | if err != nil {
|
---|
123 | log.Fatal(err)
|
---|
124 | }
|
---|
125 | fmt.Printf("Prob: %v\n", eprob)
|
---|
126 | fmt.Printf("TCfg: %v\n\n", eprob.TreeCfg)
|
---|
127 |
|
---|
128 | // setup log dir and open main log files
|
---|
129 | DC.logDir += eprob.Name + "/"
|
---|
130 | if DC.srchCfg[0][:4] == "pge1" {
|
---|
131 | DC.logDir += "pge1/"
|
---|
132 | } else if DC.srchCfg[0][:3] == "pge" {
|
---|
133 | DC.logDir += "pge/"
|
---|
134 | }
|
---|
135 |
|
---|
136 | os.MkdirAll(DC.logDir, os.ModePerm)
|
---|
137 |
|
---|
138 | now := time.Now()
|
---|
139 | fmt.Println("LogDir: ", DC.logDir)
|
---|
140 | os.MkdirAll(DC.logDir, os.ModePerm)
|
---|
141 | DS.initLogs(DC.logDir)
|
---|
142 |
|
---|
143 | DS.mainLog.Println(DC.logDir, now)
|
---|
144 |
|
---|
145 | // // setup data
|
---|
146 | fmt.Printf("Setting up problem: %s\n", eprob.Name)
|
---|
147 |
|
---|
148 | eprob.Train = make([]*probs.PointSet, len(eprob.TrainFns))
|
---|
149 | for i, fn := range eprob.TrainFns {
|
---|
150 | fmt.Printf("Reading Trainging File: %s\n", fn)
|
---|
151 | eprob.Train[i] = new(probs.PointSet)
|
---|
152 | if strings.HasSuffix(fn, ".dataF2") || strings.HasSuffix(fn, ".mat") {
|
---|
153 | eprob.Train[i].ReadLakeFile(DC.dataDir + fn)
|
---|
154 | } else {
|
---|
155 | eprob.Train[i].ReadPointSet(DC.dataDir + fn)
|
---|
156 | }
|
---|
157 | }
|
---|
158 | eprob.Test = make([]*probs.PointSet, len(eprob.TestFns))
|
---|
159 | for i, fn := range eprob.TestFns {
|
---|
160 | fmt.Printf("Reading Testing File: %s\n", fn)
|
---|
161 | eprob.Test[i] = new(probs.PointSet)
|
---|
162 | if strings.HasSuffix(fn, ".dataF2") || strings.HasSuffix(fn, ".mat") {
|
---|
163 | eprob.Test[i].ReadLakeFile(DC.dataDir + fn)
|
---|
164 | } else {
|
---|
165 | eprob.Test[i].ReadPointSet(DC.dataDir + fn)
|
---|
166 | }
|
---|
167 | }
|
---|
168 |
|
---|
169 | DS.prob = eprob
|
---|
170 | fmt.Println()
|
---|
171 |
|
---|
172 | // read search configs
|
---|
173 | for _, cfg := range DC.srchCfg {
|
---|
174 | if cfg[:4] == "pge1" {
|
---|
175 | GS := new(pge.PgeSearch)
|
---|
176 | GS.ParseConfig(DC.cfgDir + cfg)
|
---|
177 | DS.srch = append(DS.srch, GS)
|
---|
178 | } else if cfg[:3] == "pge" {
|
---|
179 | PS := new(pge.PgeSearch)
|
---|
180 | PS.ParseConfig(DC.cfgDir + cfg)
|
---|
181 | DS.srch = append(DS.srch, PS)
|
---|
182 |
|
---|
183 | /************/
|
---|
184 | // temporary hack
|
---|
185 | DS.prob.MaxIter = PS.GetMaxIter()
|
---|
186 | /************/
|
---|
187 | if *arg_pge_iter >= 0 {
|
---|
188 | DS.prob.MaxIter = *arg_pge_iter
|
---|
189 | PS.SetMaxIter(*arg_pge_iter)
|
---|
190 | }
|
---|
191 | if *arg_pge_peel >= 0 {
|
---|
192 | PS.SetPeelCount(*arg_pge_peel)
|
---|
193 | }
|
---|
194 | if *arg_pge_init != "" {
|
---|
195 | PS.SetInitMethod(*arg_pge_init)
|
---|
196 | }
|
---|
197 | if *arg_pge_grow != "" {
|
---|
198 | PS.SetGrowMethod(*arg_pge_grow)
|
---|
199 | }
|
---|
200 | PS.SetEvalrCount(*arg_pge_evals)
|
---|
201 |
|
---|
202 | } else {
|
---|
203 | log.Fatalf("unknown config type: %v from %v\n", cfg[:4], cfg)
|
---|
204 | }
|
---|
205 | }
|
---|
206 |
|
---|
207 | // setup best results
|
---|
208 | DS.eqns = make(probs.ExprReportArray, 32)
|
---|
209 | DS.per_eqns = make([]*probs.ExprReportArray, len(DS.srch))
|
---|
210 |
|
---|
211 | // setup communication struct
|
---|
212 | DS.comm = make([]*probs.ExprProblemComm, len(DS.srch))
|
---|
213 | for i, _ := range DS.comm {
|
---|
214 | DS.comm[i] = new(probs.ExprProblemComm)
|
---|
215 | DS.comm[i].Cmds = make(chan int)
|
---|
216 | DS.comm[i].Rpts = make(chan *probs.ExprReportArray, 64)
|
---|
217 | DS.comm[i].Gen = make(chan [2]int, 64)
|
---|
218 | }
|
---|
219 |
|
---|
220 | DS.iter = make([]int, len(DS.srch))
|
---|
221 |
|
---|
222 | fmt.Println("\n******************************************************\n")
|
---|
223 |
|
---|
224 | // initialize searches
|
---|
225 | sdone := make(chan int)
|
---|
226 | for i, _ := range DS.srch {
|
---|
227 | DS.srch[i].Init(sdone, eprob, DC.logDir, DS.comm[i])
|
---|
228 | }
|
---|
229 | fmt.Println("\n******************************************************\n")
|
---|
230 |
|
---|
231 | }
|
---|
232 |
|
---|
233 | func (DS *MainSearch) Run() {
|
---|
234 | fmt.Printf("Running Main\n")
|
---|
235 | fmt.Println("numSrch = ", len(DS.srch))
|
---|
236 | for i := 0; i < len(DS.srch); i++ {
|
---|
237 | go DS.srch[i].Run()
|
---|
238 | }
|
---|
239 | counter := 0
|
---|
240 | for {
|
---|
241 | // fmt.Println("DS: ", counter)
|
---|
242 | DS.checkMessages()
|
---|
243 |
|
---|
244 | // time.Sleep(time.Second / 20)
|
---|
245 | counter++
|
---|
246 |
|
---|
247 | if DS.checkStop() {
|
---|
248 | DS.doStop()
|
---|
249 | break
|
---|
250 | }
|
---|
251 | }
|
---|
252 |
|
---|
253 | for i, R := range DS.eqns {
|
---|
254 | if R == nil || R.Expr() == nil {
|
---|
255 | continue
|
---|
256 | }
|
---|
257 | trn := DS.prob.Train[0]
|
---|
258 | f_x := "df(" + trn.GetIndepNames()[DS.prob.SearchVar] + ")"
|
---|
259 | str := R.Expr().PrettyPrint(trn.GetIndepNames(), trn.GetSysNames(), R.Coeff())
|
---|
260 | fmt.Printf("%d: %s = %s\n%v\n\n", i, f_x, str, R)
|
---|
261 | }
|
---|
262 |
|
---|
263 | DS.Clean()
|
---|
264 |
|
---|
265 | fmt.Println("DS leaving Run()")
|
---|
266 | }
|
---|
267 |
|
---|
268 | func (DS *MainSearch) Clean() {
|
---|
269 | fmt.Printf("Cleaning Main\n")
|
---|
270 |
|
---|
271 | DS.errLogBuf.Flush()
|
---|
272 | DS.mainLogBuf.Flush()
|
---|
273 | DS.eqnsLogBuf.Flush()
|
---|
274 |
|
---|
275 | }
|
---|
276 |
|
---|
277 | func (DS *MainSearch) checkStop() bool {
|
---|
278 | if DS.iter[0] > DS.prob.MaxIter {
|
---|
279 | return true
|
---|
280 | }
|
---|
281 | return false
|
---|
282 | }
|
---|
283 |
|
---|
284 | func (DS *MainSearch) doStop() {
|
---|
285 | done := make(chan int)
|
---|
286 |
|
---|
287 | for i, _ := range DS.comm {
|
---|
288 | func() {
|
---|
289 | c := i
|
---|
290 | C := DS.comm[i]
|
---|
291 | go func() {
|
---|
292 | C.Cmds <- -1
|
---|
293 | fmt.Printf("DS sent -1 to Srch %d\n", c)
|
---|
294 | <-C.Cmds
|
---|
295 | done <- 1
|
---|
296 | }()
|
---|
297 | }()
|
---|
298 | }
|
---|
299 |
|
---|
300 | cnt := 0
|
---|
301 | for cnt < len(DS.comm) {
|
---|
302 | DS.checkMessages()
|
---|
303 | _, ok := <-done
|
---|
304 | if ok {
|
---|
305 | cnt++
|
---|
306 | fmt.Println("DS done = ", cnt, len(DS.comm))
|
---|
307 | }
|
---|
308 |
|
---|
309 | }
|
---|
310 |
|
---|
311 | fmt.Println("DAMD checking last messages")
|
---|
312 | DS.checkMessages()
|
---|
313 |
|
---|
314 | fmt.Println("DS done stopping")
|
---|
315 | }
|
---|
316 |
|
---|
317 | func (DS *MainSearch) checkMessages() {
|
---|
318 | msg := false
|
---|
319 | for i := 0; i < len(DS.comm); i++ {
|
---|
320 | select {
|
---|
321 | case gen, ok := <-DS.comm[i].Gen:
|
---|
322 | if ok {
|
---|
323 | DS.iter[i] = gen[1]
|
---|
324 | if gen[0] == 0 {
|
---|
325 | fmt.Println("Gen: ", gen[1])
|
---|
326 | }
|
---|
327 | i--
|
---|
328 | msg = true
|
---|
329 | }
|
---|
330 | case rpt, ok := <-DS.comm[i].Rpts:
|
---|
331 | if ok {
|
---|
332 | msg = true
|
---|
333 | DS.per_eqns[i] = rpt
|
---|
334 | i--
|
---|
335 | }
|
---|
336 | default:
|
---|
337 | continue
|
---|
338 | }
|
---|
339 | }
|
---|
340 | if !msg {
|
---|
341 | time.Sleep(time.Millisecond)
|
---|
342 | }
|
---|
343 | DS.accumExprs()
|
---|
344 | }
|
---|
345 |
|
---|
346 | func (DS *MainSearch) accumExprs() {
|
---|
347 | union := make(probs.ExprReportArray, 0)
|
---|
348 | for i := 0; i < len(DS.per_eqns); i++ {
|
---|
349 | if DS.per_eqns[i] != nil {
|
---|
350 | union = append(union, (*DS.per_eqns[i])[:]...)
|
---|
351 | }
|
---|
352 | }
|
---|
353 | union = append(union, DS.eqns[:]...)
|
---|
354 |
|
---|
355 | // remove duplicates
|
---|
356 | sort.Sort(union)
|
---|
357 | last := 0
|
---|
358 | for last < len(union) && union[last] == nil {
|
---|
359 | last++
|
---|
360 | }
|
---|
361 | for i := last + 1; i < len(union); i++ {
|
---|
362 | if union[i] == nil {
|
---|
363 | continue
|
---|
364 | }
|
---|
365 | if union[i].Expr().AmIAlmostSame(union[last].Expr()) {
|
---|
366 | union[i] = nil
|
---|
367 | } else {
|
---|
368 | last = i
|
---|
369 | }
|
---|
370 | }
|
---|
371 |
|
---|
372 | queue := probs.NewQueueFromArray(union)
|
---|
373 | queue.SetSort(probs.GPSORT_PARETO_TST_ERR)
|
---|
374 | queue.Sort()
|
---|
375 |
|
---|
376 | copy(DS.eqns, union[:len(DS.eqns)])
|
---|
377 |
|
---|
378 | // DS.eqnsLog.Printf("\n\n\nLatest Eqns:\n")
|
---|
379 | // DS.eqnsLog.Println(DS.eqns)
|
---|
380 | // DS.Clean()
|
---|
381 |
|
---|
382 | }
|
---|
383 |
|
---|
384 | func (DS *MainSearch) initLogs(logdir string) {
|
---|
385 |
|
---|
386 | // open logs
|
---|
387 | DS.logDir = logdir
|
---|
388 | os.Mkdir(DS.logDir, os.ModePerm)
|
---|
389 | tmpF0, err5 := os.Create(DS.logDir + "main:err.log")
|
---|
390 | if err5 != nil {
|
---|
391 | log.Fatal("couldn't create errs log", err5)
|
---|
392 | }
|
---|
393 | DS.errLogBuf = bufio.NewWriter(tmpF0)
|
---|
394 | DS.errLogBuf.Flush()
|
---|
395 | DS.errLog = log.New(DS.errLogBuf, "", log.LstdFlags)
|
---|
396 |
|
---|
397 | tmpF1, err1 := os.Create(DS.logDir + "main:main.log")
|
---|
398 | if err1 != nil {
|
---|
399 | log.Fatal("couldn't create main log", err1)
|
---|
400 | }
|
---|
401 | DS.mainLogBuf = bufio.NewWriter(tmpF1)
|
---|
402 | DS.mainLogBuf.Flush()
|
---|
403 | DS.mainLog = log.New(DS.mainLogBuf, "", log.LstdFlags)
|
---|
404 |
|
---|
405 | tmpF2, err2 := os.Create(DS.logDir + "main:eqns.log")
|
---|
406 | if err2 != nil {
|
---|
407 | log.Fatal("couldn't create eqns log", err2)
|
---|
408 | }
|
---|
409 | DS.eqnsLogBuf = bufio.NewWriter(tmpF2)
|
---|
410 | DS.eqnsLogBuf.Flush()
|
---|
411 | DS.eqnsLog = log.New(DS.eqnsLogBuf, "", log.LstdFlags)
|
---|
412 |
|
---|
413 | }
|
---|