Changeset 12874 for branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
- Timestamp:
- 08/17/15 18:38:17 (9 years ago)
- Location:
- branches/crossvalidation-2434
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434
- Property svn:mergeinfo changed
/trunk/sources merged: 12873
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Algorithms.DataAnalysis merged: 12873
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
r12872 r12874 82 82 get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 83 83 } 84 public IConstrainedValueParameter< StringValue> LossFunctionParameter {85 get { return (IConstrainedValueParameter< StringValue>)Parameters[LossFunctionParameterName]; }84 public IConstrainedValueParameter<ILossFunction> LossFunctionParameter { 85 get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; } 86 86 } 87 87 public IFixedValueParameter<IntValue> UpdateIntervalParameter { … … 164 164 Parameters[CreateSolutionParameterName].Hidden = true; 165 165 166 var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly()); 167 Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames))); 168 LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default 169 } 170 166 var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>(); 167 Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions))); 168 LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default 169 } 170 171 [StorableHook(HookType.AfterDeserialization)] 172 private void AfterDeserialization() { 173 // BackwardsCompatibility3.4 174 #region Backwards compatible code, remove with 3.5 175 // parameter type has been changed 176 var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>; 177 if (lossFunctionParam != null) { 178 Parameters.Remove(LossFunctionParameterName); 179 var selectedValue = lossFunctionParam.Value; // to be restored below 180 181 var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>(); 182 Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions))); 183 // try to restore selected value 184 var selectedLossFunction = 185 LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value); 186 if (selectedLossFunction != null) { 187 LossFunctionParameter.Value = selectedLossFunction; 188 } else { 189 LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE 190 } 191 } 192 #endregion 193 } 171 194 172 195 protected override void Run(CancellationToken cancellationToken) { … … 187 210 // init 188 211 var problemData = (IRegressionProblemData)Problem.ProblemData.Clone(); 189 var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>() 190 .Single(l => l.ToString() == LossFunctionParameter.Value.Value); 212 var lossFunction = LossFunctionParameter.Value; 191 213 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu); 192 214 … … 233 255 // produce solution 234 256 if (CreateSolution) { 235 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction .ToString(),257 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction, 236 258 Iterations, MaxSize, R, M, Nu, state.GetModel()); 237 259
Note: See TracChangeset
for help on using the changeset viewer.