source: trunk/sources/HeuristicLab.Tests/HeuristicLab.Scripting-3.3/GridSearchScriptTest.cs @ 11483

Last change on this file since 11483 was 11483, checked in by bburlacu, 6 years ago

#2211: Updated script unit tests to wait for the script thread to finish before validating results.

File size: 7.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.CodeDom.Compiler;
24using System.IO;
25using System.Linq;
26using System.Reflection;
27using System.Threading;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Optimizer;
31using HeuristicLab.Persistence.Default.Xml;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Scripting;
34using Microsoft.VisualStudio.TestTools.UnitTesting;
35
36namespace HeuristicLab.Tests {
37  [TestClass]
38  public class GridSearchScriptTest {
39    private const string PathPrefix = "HeuristicLab.Optimizer.Documents.";
40    private const string PathSuffix = ".hl";
41    private const string SvmClassificationScriptName = "GridSearch_SVM_Classification";
42    private const string SvmRegressionScriptName = "GridSearch_SVM_Regression";
43    private const string RandomForestRegressionScriptName = "GridSearch_RF_Regression";
44    private const string RandomForestClassificationScriptName = "GridSearch_RF_Classification";
45    private const string SamplesDirectory = SamplesUtils.Directory;
46
47    private readonly ManualResetEvent manualResetEvent = new ManualResetEvent(false);
48
49    [ClassInitialize]
50    public static void MyClassInitialize(TestContext testContext) {
51      PluginLoader.Assemblies.Any();
52      if (!Directory.Exists(SamplesDirectory))
53        Directory.CreateDirectory(SamplesDirectory);
54    }
55
56    [TestMethod]
57    [TestCategory("Scripting")]
58    public void RunRandomForestRegressionScriptTest() {
59      var assembly = new StartPage().GetType().Assembly;
60      const string name = PathPrefix + RandomForestRegressionScriptName + PathSuffix;
61      var script = (CSharpScript)LoadSample(name, assembly);
62
63      try {
64        script.Compile();
65      }
66      catch {
67        if (script.CompileErrors.HasErrors) {
68          ShowCompilationResults(script);
69          throw new Exception("Compilation failed.");
70        } else {
71          Console.WriteLine("Compilation succeeded.");
72        }
73      }
74      finally {
75        script.ScriptExecutionFinished += script_ExecutionFinished;
76        script.Execute();
77        var vs = script.VariableStore;
78        var solution = (IRegressionSolution)vs["demo_bestSolution"];
79        Assert.IsTrue(solution.TrainingRSquared.IsAlmost(1));
80      }
81    }
82
83    [TestMethod]
84    [TestCategory("Scripting")]
85    public void RunRandomForestClassificationScriptTest() {
86      var assembly = new StartPage().GetType().Assembly;
87      const string name = PathPrefix + RandomForestClassificationScriptName + PathSuffix;
88      var script = (CSharpScript)LoadSample(name, assembly);
89
90      try {
91        script.Compile();
92      }
93      catch {
94        if (script.CompileErrors.HasErrors) {
95          ShowCompilationResults(script);
96          throw new Exception("Compilation failed.");
97        } else {
98          Console.WriteLine("Compilation succeeded.");
99        }
100      }
101      finally {
102        script.ScriptExecutionFinished += script_ExecutionFinished;
103        script.Execute();
104        var vs = script.VariableStore;
105        var solution = (IClassificationSolution)vs["demo_bestSolution"];
106        Assert.IsTrue(solution.TrainingAccuracy.IsAlmost(1) && solution.TestAccuracy.IsAlmost(0.953125));
107      }
108    }
109
110    [TestMethod]
111    [TestCategory("Scripting")]
112    public void RunSvmRegressionScriptTest() {
113      var assembly = new StartPage().GetType().Assembly;
114      const string name = PathPrefix + SvmRegressionScriptName + PathSuffix;
115      var script = (CSharpScript)LoadSample(name, assembly);
116
117      try {
118        script.Compile();
119      }
120      catch {
121        if (script.CompileErrors.HasErrors) {
122          ShowCompilationResults(script);
123          throw new Exception("Compilation failed.");
124        } else {
125          Console.WriteLine("Compilation succeeded.");
126        }
127      }
128      finally {
129        script.ScriptExecutionFinished += script_ExecutionFinished;
130        script.Execute();
131        var vs = script.VariableStore;
132        var solution = (IRegressionSolution)vs["demo_bestSolution"];
133        Assert.IsTrue(solution.TrainingRSquared.IsAlmost(0.066221959224331) && solution.TestRSquared.IsAlmost(0.0794407638195883));
134      }
135    }
136
137    [TestMethod]
138    [TestCategory("Scripting")]
139    public void RunSvmClassificationScriptTest() {
140      var assembly = new StartPage().GetType().Assembly;
141      const string name = PathPrefix + SvmClassificationScriptName + PathSuffix;
142      var script = (CSharpScript)LoadSample(name, assembly);
143
144      try {
145        script.Compile();
146      }
147      catch {
148        if (script.CompileErrors.HasErrors) {
149          ShowCompilationResults(script);
150          throw new Exception("Compilation failed.");
151        } else {
152          Console.WriteLine("Compilation succeeded.");
153        }
154      }
155      finally {
156        script.ScriptExecutionFinished += script_ExecutionFinished;
157        script.Execute();
158        manualResetEvent.WaitOne();
159        var vs = script.VariableStore;
160        var solution = (IClassificationSolution)vs["demo_bestSolution"];
161        Assert.IsTrue(solution.TrainingAccuracy.IsAlmost(0.817472698907956) && solution.TestAccuracy.IsAlmost(0.809375));
162      }
163    }
164
165    #region Helpers
166    private static void ShowCompilationResults(Script script) {
167      if (script.CompileErrors.Count == 0) return;
168      var msgs = script.CompileErrors.OfType<CompilerError>()
169                                      .OrderBy(x => x.IsWarning)
170                                      .ThenBy(x => x.Line)
171                                      .ThenBy(x => x.Column);
172      foreach (var m in msgs) {
173        Console.WriteLine(m);
174      }
175    }
176
177    private INamedItem LoadSample(string name, Assembly assembly) {
178      string path = Path.GetTempFileName();
179      INamedItem item = null;
180      try {
181        using (var stream = assembly.GetManifestResourceStream(name)) {
182          WriteStreamToTempFile(stream, path); // create a file in a temporary folder (persistence cannot load these files directly from the stream)
183          item = XmlParser.Deserialize<INamedItem>(path);
184        }
185      }
186      catch (Exception) {
187      }
188      finally {
189        if (File.Exists(path)) {
190          File.Delete(path); // make sure we remove the temporary file
191        }
192      }
193      return item;
194    }
195
196    private void WriteStreamToTempFile(Stream stream, string path) {
197      using (FileStream output = new FileStream(path, FileMode.Create, FileAccess.Write)) {
198        stream.CopyTo(output);
199      }
200    }
201
202    private void script_ExecutionFinished(object sender, EventArgs a) {
203      manualResetEvent.Set();
204    }
205    #endregion
206  }
207}
Note: See TracBrowser for help on using the repository browser.