- Timestamp:
- 06/15/16 10:02:15 (8 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs
r13157 r13895 1 1 using System; 2 using System.Collections; 3 using System.IO; 2 4 using System.Linq; 3 5 using System.Threading; … … 160 162 // x2 > 1.5 AND x1 > 1.5 -> 3.0 161 163 BuildTree(xy, allVariables, 10); 164 } 165 } 166 167 [TestMethod] 168 [TestCategory("Algorithms.DataAnalysis")] 169 [TestProperty("Time", "short")] 170 public void TestDecisionTreePartialDependence() { 171 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider(); 172 var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower")); 173 var regProblem = new RegressionProblem(); 174 regProblem.Load(provider.LoadData(instance)); 175 var problemData = regProblem.ProblemData; 176 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02); 177 for (int i = 0; i < 1000; i++) 178 GradientBoostedTreesAlgorithmStatic.MakeStep(state); 179 180 181 var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First(); 182 Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value); 183 var model = ((IGradientBoostedTreesModel)state.GetModel()); 184 var treeM = model.Models.Skip(1).First(); 185 Console.WriteLine(treeM.ToString()); 186 Console.WriteLine(); 187 188 var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray(); 189 var ds = new ModifiableDataset(new string[] { mostImportantVar.Key }, 190 new IList[] { mostImportantVarValues.ToList<double>() }); 191 192 var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray(); 193 194 for (int i = 0; i < mostImportantVarValues.Length; i += 10) { 195 Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]); 196 } 197 } 198 199 [TestMethod] 200 [TestCategory("Algorithms.DataAnalysis")] 201 [TestProperty("Time", "short")] 202 public void TestDecisionTreePersistence() { 203 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider(); 204 var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower")); 205 var regProblem = new RegressionProblem(); 206 regProblem.Load(provider.LoadData(instance)); 207 var problemData = regProblem.ProblemData; 208 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1); 209 GradientBoostedTreesAlgorithmStatic.MakeStep(state); 210 211 var model = ((IGradientBoostedTreesModel)state.GetModel()); 212 var treeM = model.Models.Skip(1).First(); 213 var origStr = treeM.ToString(); 214 using (var memStream = new MemoryStream()) { 215 Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream); 216 var buf = memStream.GetBuffer(); 217 using (var restoreStream = new MemoryStream(buf)) { 218 var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream); 219 var restoredStr = restoredTree.ToString(); 220 Assert.AreEqual(origStr, restoredStr); 221 } 162 222 } 163 223 }
Note: See TracChangeset
for help on using the changeset viewer.