Changeset 15973 for branches/2522_RefactorPluginInfrastructure/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs
- Timestamp:
- 06/28/18 11:13:37 (6 years ago)
- Location:
- branches/2522_RefactorPluginInfrastructure
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2522_RefactorPluginInfrastructure
- Property svn:ignore
-
old new 24 24 protoc.exe 25 25 obj 26 .vs
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Tests
- Property svn:mergeinfo changed
-
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs
r13157 r15973 1 1 using System; 2 using System.Collections; 3 using System.IO; 2 4 using System.Linq; 3 using System.Threading;4 5 using HeuristicLab.Data; 5 using HeuristicLab.Optimization;6 6 using HeuristicLab.Problems.DataAnalysis; 7 7 using Microsoft.VisualStudio.TestTools.UnitTesting; … … 165 165 [TestMethod] 166 166 [TestCategory("Algorithms.DataAnalysis")] 167 [TestProperty("Time", "short")] 168 public void TestDecisionTreePartialDependence() { 169 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider(); 170 var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower")); 171 var regProblem = new RegressionProblem(); 172 regProblem.Load(provider.LoadData(instance)); 173 var problemData = regProblem.ProblemData; 174 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02); 175 for (int i = 0; i < 1000; i++) 176 GradientBoostedTreesAlgorithmStatic.MakeStep(state); 177 178 179 var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First(); 180 Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value); 181 var model = ((IGradientBoostedTreesModel)state.GetModel()); 182 var treeM = model.Models.Skip(1).First(); 183 Console.WriteLine(treeM.ToString()); 184 Console.WriteLine(); 185 186 var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray(); 187 var ds = new ModifiableDataset(new string[] { mostImportantVar.Key }, 188 new IList[] { mostImportantVarValues.ToList<double>() }); 189 190 var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray(); 191 192 for (int i = 0; i < mostImportantVarValues.Length; i += 10) { 193 Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]); 194 } 195 } 196 197 [TestMethod] 198 [TestCategory("Algorithms.DataAnalysis")] 199 [TestProperty("Time", "short")] 200 public void TestDecisionTreePersistence() { 201 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider(); 202 var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower")); 203 var regProblem = new RegressionProblem(); 204 regProblem.Load(provider.LoadData(instance)); 205 var problemData = regProblem.ProblemData; 206 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1); 207 GradientBoostedTreesAlgorithmStatic.MakeStep(state); 208 209 var model = ((IGradientBoostedTreesModel)state.GetModel()); 210 var treeM = model.Models.Skip(1).First(); 211 var origStr = treeM.ToString(); 212 using (var memStream = new MemoryStream()) { 213 Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream); 214 var buf = memStream.GetBuffer(); 215 using (var restoreStream = new MemoryStream(buf)) { 216 var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream); 217 var restoredStr = restoredTree.ToString(); 218 Assert.AreEqual(origStr, restoredStr); 219 } 220 } 221 } 222 223 [TestMethod] 224 [TestCategory("Algorithms.DataAnalysis")] 167 225 [TestProperty("Time", "long")] 168 226 public void GradientBoostingTestTowerSquaredError() { … … 182 240 #endregion 183 241 184 RunAlgorithm(gbt);242 gbt.Start(); 185 243 186 244 Console.WriteLine(gbt.ExecutionTime); … … 210 268 #endregion 211 269 212 RunAlgorithm(gbt);270 gbt.Start(); 213 271 214 272 Console.WriteLine(gbt.ExecutionTime); … … 238 296 #endregion 239 297 240 RunAlgorithm(gbt);298 gbt.Start(); 241 299 242 300 Console.WriteLine(gbt.ExecutionTime); 243 301 Assert.AreEqual(0.061954221604374943, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6); 244 302 Assert.AreEqual(0.06316303473499961, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6); 245 }246 247 // same as in SamplesUtil248 private void RunAlgorithm(IAlgorithm a) {249 var trigger = new EventWaitHandle(false, EventResetMode.ManualReset);250 Exception ex = null;251 a.Stopped += (src, e) => { trigger.Set(); };252 a.ExceptionOccurred += (src, e) => { ex = e.Value; trigger.Set(); };253 a.Prepare();254 a.Start();255 trigger.WaitOne();256 257 Assert.AreEqual(ex, null);258 303 } 259 304
Note: See TracChangeset
for help on using the changeset viewer.