Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassification.cs
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassification.cs (revision 8623)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassification.cs (revision 8623)
@@ -0,0 +1,196 @@
+
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using System;
+using HeuristicLab.Algorithms.GradientDescent;
+using HeuristicLab.Common;
+using HeuristicLab.Core;
+using HeuristicLab.Data;
+using HeuristicLab.Operators;
+using HeuristicLab.Optimization;
+using HeuristicLab.Parameters;
+using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
+using HeuristicLab.Problems.DataAnalysis;
+
+namespace HeuristicLab.Algorithms.DataAnalysis {
+  /// <summary>
+  /// Gaussian process least-squares classification data analysis algorithm.
+  /// </summary>
+  [Item("Gaussian Process Least-Squares Classification", "Gaussian process least-squares classification data analysis algorithm.")]
+  [Creatable("Data Analysis")]
+  [StorableClass]
+  public sealed class GaussianProcessClassification : EngineAlgorithm, IStorableContent {
+    public string Filename { get; set; }
+
+    public override Type ProblemType { get { return typeof(IClassificationProblem); } }
+    public new IClassificationProblem Problem {
+      get { return (IClassificationProblem)base.Problem; }
+      set { base.Problem = value; }
+    }
+
+    private const string MeanFunctionParameterName = "MeanFunction";
+    private const string CovarianceFunctionParameterName = "CovarianceFunction";
+    private const string MinimizationIterationsParameterName = "Iterations";
+    private const string ApproximateGradientsParameterName = "ApproximateGradients";
+    private const string SeedParameterName = "Seed";
+    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
+
+    #region parameter properties
+    public IValueParameter<IMeanFunction> MeanFunctionParameter {
+      get { return (IValueParameter<IMeanFunction>)Parameters[MeanFunctionParameterName]; }
+    }
+    public IValueParameter<ICovarianceFunction> CovarianceFunctionParameter {
+      get { return (IValueParameter<ICovarianceFunction>)Parameters[CovarianceFunctionParameterName]; }
+    }
+    public IValueParameter<IntValue> MinimizationIterationsParameter {
+      get { return (IValueParameter<IntValue>)Parameters[MinimizationIterationsParameterName]; }
+    }
+    public IValueParameter<IntValue> SeedParameter {
+      get { return (IValueParameter<IntValue>)Parameters[SeedParameterName]; }
+    }
+    public IValueParameter<BoolValue> SetSeedRandomlyParameter {
+      get { return (IValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
+    }
+    #endregion
+    #region properties
+    public IMeanFunction MeanFunction {
+      set { MeanFunctionParameter.Value = value; }
+      get { return MeanFunctionParameter.Value; }
+    }
+    public ICovarianceFunction CovarianceFunction {
+      set { CovarianceFunctionParameter.Value = value; }
+      get { return CovarianceFunctionParameter.Value; }
+    }
+    public int MinimizationIterations {
+      set { MinimizationIterationsParameter.Value.Value = value; }
+      get { return MinimizationIterationsParameter.Value.Value; }
+    }
+    public int Seed { get { return SeedParameter.Value.Value; } set { SeedParameter.Value.Value = value; } }
+    public bool SetSeedRandomly { get { return SetSeedRandomlyParameter.Value.Value; } set { SetSeedRandomlyParameter.Value.Value = value; } }
+    #endregion
+
+    [StorableConstructor]
+    private GaussianProcessClassification(bool deserializing) : base(deserializing) { }
+    private GaussianProcessClassification(GaussianProcessClassification original, Cloner cloner)
+      : base(original, cloner) {
+    }
+    public GaussianProcessClassification()
+      : base() {
+      this.name = ItemName;
+      this.description = ItemDescription;
+
+      Problem = new ClassificationProblem();
+      // user-visible algorithm parameters
+      Parameters.Add(new ValueParameter<IMeanFunction>(MeanFunctionParameterName, "The mean function to use.", new MeanConst()));
+      Parameters.Add(new ValueParameter<ICovarianceFunction>(CovarianceFunctionParameterName, "The covariance function to use.", new CovarianceSquaredExponentialIso()));
+      Parameters.Add(new ValueParameter<IntValue>(MinimizationIterationsParameterName, "The number of iterations for likelihood optimization with LM-BFGS.", new IntValue(20)));
+      Parameters.Add(new ValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
+      Parameters.Add(new ValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
+      // fixed to false and hidden: the gradients are taken from the model creator (see updateResults below) instead of being approximated
+      Parameters.Add(new ValueParameter<BoolValue>(ApproximateGradientsParameterName, "Indicates that gradients should not be approximated (necessary for LM-BFGS).", new BoolValue(false)));
+      Parameters[ApproximateGradientsParameterName].Hidden = true; // should not be changed
+      // operator graph: initialization, then the loop makeStep -> branch -> modelCreator -> updateResults -> analyzer -> makeStep
+      var randomCreator = new HeuristicLab.Random.RandomCreator();
+      var gpInitializer = new GaussianProcessHyperparameterInitializer();
+      var bfgsInitializer = new LbfgsInitializer();
+      var makeStep = new LbfgsMakeStep();
+      var branch = new ConditionalBranch();
+      var modelCreator = new GaussianProcessClassificationModelCreator();
+      var updateResults = new LbfgsUpdateResults();
+      var analyzer = new LbfgsAnalyzer();
+      var finalModelCreator = new GaussianProcessClassificationModelCreator();
+      var finalAnalyzer = new LbfgsAnalyzer();
+      var solutionCreator = new GaussianProcessClassificationSolutionCreator();
+      // the random number generator is configured through the Seed and SetSeedRandomly parameters of the algorithm
+      OperatorGraph.InitialOperator = randomCreator;
+      randomCreator.SeedParameter.ActualName = SeedParameterName;
+      randomCreator.SeedParameter.Value = null;
+      randomCreator.SetSeedRandomlyParameter.ActualName = SetSeedRandomlyParameterName;
+      randomCreator.SetSeedRandomlyParameter.Value = null;
+      randomCreator.Successor = gpInitializer;
+      // the hyperparameter initializer writes the vector that the L-BFGS operators use as their point
+      gpInitializer.CovarianceFunctionParameter.ActualName = CovarianceFunctionParameterName;
+      gpInitializer.MeanFunctionParameter.ActualName = MeanFunctionParameterName;
+      gpInitializer.ProblemDataParameter.ActualName = Problem.ProblemDataParameter.Name;
+      gpInitializer.HyperparameterParameter.ActualName = modelCreator.HyperparameterParameter.Name;
+      gpInitializer.RandomParameter.ActualName = randomCreator.RandomParameter.Name;
+      gpInitializer.Successor = bfgsInitializer;
+      // L-BFGS setup for the hyperparameter vector
+      bfgsInitializer.IterationsParameter.ActualName = MinimizationIterationsParameterName;
+      bfgsInitializer.PointParameter.ActualName = modelCreator.HyperparameterParameter.Name;
+      bfgsInitializer.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
+      bfgsInitializer.Successor = makeStep;
+      // one optimization step on the shared L-BFGS state
+      makeStep.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
+      makeStep.PointParameter.ActualName = modelCreator.HyperparameterParameter.Name;
+      makeStep.Successor = branch;
+      // termination criterion true: create the final model; false: evaluate the current point and continue
+      branch.ConditionParameter.ActualName = makeStep.TerminationCriterionParameter.Name;
+      branch.FalseBranch = modelCreator;
+      branch.TrueBranch = finalModelCreator;
+      // loop body: model creation for the current hyperparameters
+      modelCreator.ProblemDataParameter.ActualName = Problem.ProblemDataParameter.Name;
+      modelCreator.MeanFunctionParameter.ActualName = MeanFunctionParameterName;
+      modelCreator.CovarianceFunctionParameter.ActualName = CovarianceFunctionParameterName;
+      modelCreator.Successor = updateResults;
+      // report the negative log likelihood and its gradients back to the L-BFGS state
+      updateResults.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
+      updateResults.QualityParameter.ActualName = modelCreator.NegativeLogLikelihoodParameter.Name;
+      updateResults.QualityGradientsParameter.ActualName = modelCreator.HyperparameterGradientsParameter.Name;
+      updateResults.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
+      updateResults.Successor = analyzer;
+      // track point, quality and gradients in result tables, then start the next step
+      analyzer.QualityParameter.ActualName = modelCreator.NegativeLogLikelihoodParameter.Name;
+      analyzer.PointParameter.ActualName = modelCreator.HyperparameterParameter.Name;
+      analyzer.QualityGradientsParameter.ActualName = modelCreator.HyperparameterGradientsParameter.Name;
+      analyzer.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
+      analyzer.PointsTableParameter.ActualName = "Hyperparameter table";
+      analyzer.QualityGradientsTableParameter.ActualName = "Gradients table";
+      analyzer.QualitiesTableParameter.ActualName = "Negative log likelihood table";
+      analyzer.Successor = makeStep;
+      // final model for the point referenced by the L-BFGS initializer's point parameter
+      finalModelCreator.ProblemDataParameter.ActualName = Problem.ProblemDataParameter.Name;
+      finalModelCreator.MeanFunctionParameter.ActualName = MeanFunctionParameterName;
+      finalModelCreator.CovarianceFunctionParameter.ActualName = CovarianceFunctionParameterName;
+      finalModelCreator.HyperparameterParameter.ActualName = bfgsInitializer.PointParameter.ActualName;
+      finalModelCreator.Successor = finalAnalyzer;
+      // the final analyzer writes into the same tables as the in-loop analyzer
+      finalAnalyzer.QualityParameter.ActualName = modelCreator.NegativeLogLikelihoodParameter.Name;
+      finalAnalyzer.PointParameter.ActualName = modelCreator.HyperparameterParameter.Name;
+      finalAnalyzer.QualityGradientsParameter.ActualName = modelCreator.HyperparameterGradientsParameter.Name;
+      finalAnalyzer.PointsTableParameter.ActualName = analyzer.PointsTableParameter.ActualName;
+      finalAnalyzer.QualityGradientsTableParameter.ActualName = analyzer.QualityGradientsTableParameter.ActualName;
+      finalAnalyzer.QualitiesTableParameter.ActualName = analyzer.QualitiesTableParameter.ActualName;
+      finalAnalyzer.Successor = solutionCreator;
+      // turn the final model into a classification solution
+      solutionCreator.ModelParameter.ActualName = finalModelCreator.ModelParameter.Name;
+      solutionCreator.ProblemDataParameter.ActualName = Problem.ProblemDataParameter.Name;
+    }
+
+    [StorableHook(HookType.AfterDeserialization)]
+    private void AfterDeserialization() { }
+
+    public override IDeepCloneable Clone(Cloner cloner) {
+      return new GaussianProcessClassification(this, cloner);
+    }
+  }
+}
Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassificationModelCreator.cs
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassificationModelCreator.cs (revision 8623)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassificationModelCreator.cs (revision 8623)
@@ -0,0 +1,81 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using System;
+using System.Linq;
+using HeuristicLab.Common;
+using HeuristicLab.Core;
+using HeuristicLab.Data;
+using HeuristicLab.Encodings.RealVectorEncoding;
+using HeuristicLab.Parameters;
+using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
+using HeuristicLab.Problems.DataAnalysis;
+
+namespace HeuristicLab.Algorithms.DataAnalysis {
+  [StorableClass]
+  [Item(Name = "GaussianProcessClassificationModelCreator",
+    Description = "Creates a Gaussian process model for least-squares classification given the data, the hyperparameters, a mean function, and a covariance function.")]
+  public sealed class GaussianProcessClassificationModelCreator : GaussianProcessModelCreator {
+    private const string ProblemDataParameterName = "ProblemData";
+
+    #region Parameter Properties
+    public ILookupParameter<IClassificationProblemData> ProblemDataParameter {
+      get { return (ILookupParameter<IClassificationProblemData>)Parameters[ProblemDataParameterName]; }
+    }
+    #endregion
+
+    #region Properties
+    private IClassificationProblemData ProblemData {
+      get { return ProblemDataParameter.ActualValue; }
+    }
+    #endregion
+    [StorableConstructor]
+    private GaussianProcessClassificationModelCreator(bool deserializing) : base(deserializing) { }
+    private GaussianProcessClassificationModelCreator(GaussianProcessClassificationModelCreator original, Cloner cloner) : base(original, cloner) { }
+    public GaussianProcessClassificationModelCreator()
+      : base() {
+      Parameters.Add(new LookupParameter<IClassificationProblemData>(ProblemDataParameterName, "The classification problem data for the Gaussian process model."));
+    }
+
+    public override IDeepCloneable Clone(Cloner cloner) {
+      return new GaussianProcessClassificationModelCreator(this, cloner);
+    }
+    // Builds the model for the current hyperparameters and publishes model, likelihood and gradients; on failure only a penalty likelihood and a fresh gradient vector are published.
+    public override IOperation Apply() {
+      try {
+        var model = Create(ProblemData, Hyperparameter.ToArray(), MeanFunction, CovarianceFunction);
+        ModelParameter.ActualValue = model;
+        NegativeLogLikelihoodParameter.ActualValue = new DoubleValue(model.NegativeLogLikelihood);
+        HyperparameterGradientsParameter.ActualValue = new RealVector(model.HyperparameterGradients);
+        return base.Apply();
+      }
+      catch (ArgumentException) { } // deliberately ignored: handled by the fallback below
+      catch (alglib.alglibexception) { } // deliberately ignored: handled by the fallback below
+      NegativeLogLikelihoodParameter.ActualValue = new DoubleValue(1E300); // fallback: penalty value; the model parameter is left untouched
+      HyperparameterGradientsParameter.ActualValue = new RealVector(Hyperparameter.Count());
+      return base.Apply();
+    }
+    // Creates a Gaussian process model from the training indices of the given classification problem data.
+    public static IGaussianProcessModel Create(IClassificationProblemData problemData, double[] hyperparameter, IMeanFunction meanFunction, ICovarianceFunction covarianceFunction) {
+      return new GaussianProcessModel(problemData.Dataset, problemData.TargetVariable, problemData.AllowedInputVariables, problemData.TrainingIndices, hyperparameter, meanFunction, covarianceFunction);
+    }
+  }
+}
Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassificationSolutionCreator.cs
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassificationSolutionCreator.cs (revision 8623)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessClassificationSolutionCreator.cs (revision 8623)
@@ -0,0 +1,103 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using HeuristicLab.Common;
+using HeuristicLab.Core;
+using HeuristicLab.Data;
+using HeuristicLab.Operators;
+using HeuristicLab.Optimization;
+using HeuristicLab.Parameters;
+using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
+using HeuristicLab.Problems.DataAnalysis;
+
+namespace HeuristicLab.Algorithms.DataAnalysis {
+  [StorableClass]
+  [Item(Name = "GaussianProcessClassificationSolutionCreator",
+    Description = "Creates a Gaussian process solution from a trained model.")]
+  public sealed class GaussianProcessClassificationSolutionCreator : SingleSuccessorOperator {
+    private const string ProblemDataParameterName = "ProblemData";
+    private const string ModelParameterName = "GaussianProcessClassificationModel";
+    private const string SolutionParameterName = "Solution";
+    private const string ResultsParameterName = "Results";
+    private const string TrainingAccuracyResultName = "Accuracy (training)";
+    private const string TestAccuracyResultName = "Accuracy (test)";
+
+    #region Parameter Properties
+    public ILookupParameter<IClassificationProblemData> ProblemDataParameter {
+      get { return (ILookupParameter<IClassificationProblemData>)Parameters[ProblemDataParameterName]; }
+    }
+    public ILookupParameter<IDiscriminantFunctionClassificationSolution> SolutionParameter {
+      get { return (ILookupParameter<IDiscriminantFunctionClassificationSolution>)Parameters[SolutionParameterName]; }
+    }
+    public ILookupParameter<IGaussianProcessModel> ModelParameter {
+      get { return (ILookupParameter<IGaussianProcessModel>)Parameters[ModelParameterName]; }
+    }
+    public ILookupParameter<ResultCollection> ResultsParameter {
+      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
+    }
+    #endregion
+
+    [StorableConstructor]
+    private GaussianProcessClassificationSolutionCreator(bool deserializing) : base(deserializing) { }
+    private GaussianProcessClassificationSolutionCreator(GaussianProcessClassificationSolutionCreator original, Cloner cloner) : base(original, cloner) { }
+    public GaussianProcessClassificationSolutionCreator()
+      : base() {
+      // in
+      Parameters.Add(new LookupParameter<IClassificationProblemData>(ProblemDataParameterName, "The classification problem data for the Gaussian process solution."));
+      Parameters.Add(new LookupParameter<IGaussianProcessModel>(ModelParameterName, "The Gaussian process classification model to use for the solution."));
+      // in & out
+      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection of the algorithm."));
+      // out
+      Parameters.Add(new LookupParameter<IDiscriminantFunctionClassificationSolution>(SolutionParameterName, "The produced Gaussian process solution."));
+    }
+
+    public override IDeepCloneable Clone(Cloner cloner) {
+      return new GaussianProcessClassificationSolutionCreator(this, cloner);
+    }
+    // Creates a classification solution from clones of the model and the problem data and publishes it with its accuracies; does nothing when no model is available.
+    public override IOperation Apply() {
+      if (ModelParameter.ActualValue != null) {
+        var m = (IGaussianProcessModel)ModelParameter.ActualValue.Clone();
+        var data = (IClassificationProblemData)ProblemDataParameter.ActualValue.Clone();
+        var model = new GaussianProcessDiscriminantFunctionClassificationModel(m, new NormalDistributionCutPointsThresholdCalculator());
+        model.RecalculateModelParameters(data, data.TrainingIndices);
+        var s = model.CreateDiscriminantFunctionClassificationSolution(data);
+        // publish the solution as parameter value and, together with its accuracies, in the result collection
+        SolutionParameter.ActualValue = s;
+        var results = ResultsParameter.ActualValue;
+        if (!results.ContainsKey(SolutionParameterName)) {
+          results.Add(new Result(SolutionParameterName, "The Gaussian process classification solution", s));
+          results.Add(new Result(TrainingAccuracyResultName,
+                                 "The accuracy of the Gaussian process solution on the training partition.",
+                                 new DoubleValue(s.TrainingAccuracy)));
+          results.Add(new Result(TestAccuracyResultName,
+                                 "The accuracy of the Gaussian process solution on the test partition.",
+                                 new DoubleValue(s.TestAccuracy)));
+        } else {
+          results[SolutionParameterName].Value = s;
+          results[TrainingAccuracyResultName].Value = new DoubleValue(s.TrainingAccuracy);
+          results[TestAccuracyResultName].Value = new DoubleValue(s.TestAccuracy);
+        }
+      }
+      return base.Apply();
+    }
+  }
+}
Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessDiscriminantFunctionClassificationModel.cs
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessDiscriminantFunctionClassificationModel.cs (revision 8623)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessDiscriminantFunctionClassificationModel.cs (revision 8623)
@@ -0,0 +1,63 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using HeuristicLab.Common;
+using HeuristicLab.Core;
+using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
+using HeuristicLab.Problems.DataAnalysis;
+
+namespace HeuristicLab.Algorithms.DataAnalysis {
+ /// <summary>
+ /// Represents a Gaussian process discriminant function classification model.
+ /// </summary>
+ [StorableClass]
+ [Item("GaussianProcessDiscriminantFunctionClassificationModel",
+ "Represents a Gaussian process discriminant function classification model.")]
+ public sealed class GaussianProcessDiscriminantFunctionClassificationModel : DiscriminantFunctionClassificationModel {
+ [StorableConstructor]
+ private GaussianProcessDiscriminantFunctionClassificationModel(bool deserializing)
+ : base(deserializing) {
+ }
+
+ private GaussianProcessDiscriminantFunctionClassificationModel(
+ GaussianProcessDiscriminantFunctionClassificationModel original, Cloner cloner)
+ : base(original, cloner) {
+ }
+
+ public GaussianProcessDiscriminantFunctionClassificationModel(IGaussianProcessModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator)
+ : base(model, thresholdCalculator) {
+ }
+
+ // Clones this model through the private cloning constructor.
+ public override IDeepCloneable Clone(Cloner cloner) {
+ return new GaussianProcessDiscriminantFunctionClassificationModel(this, cloner);
+ }
+
+ // Wraps this model and the given problem data in a GaussianProcessDiscriminantFunctionClassificationSolution.
+ public override IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData) {
+ return new GaussianProcessDiscriminantFunctionClassificationSolution(this, problemData);
+ }
+
+ public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
+ return CreateDiscriminantFunctionClassificationSolution(problemData);
+ }
+ }
+}
Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessDiscriminantFunctionClassificationSolution.cs
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessDiscriminantFunctionClassificationSolution.cs (revision 8623)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessDiscriminantFunctionClassificationSolution.cs (revision 8623)
@@ -0,0 +1,60 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using HeuristicLab.Common;
+using HeuristicLab.Core;
+using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
+using HeuristicLab.Problems.DataAnalysis;
+
+namespace HeuristicLab.Algorithms.DataAnalysis {
+ /// <summary>
+ /// Represents a Gaussian process discriminant function classification solution.
+ /// </summary>
+ [StorableClass]
+ [Item("GaussianProcessDiscriminantFunctionClassificationSolution",
+ "Represents a Gaussian process discriminant function classification solution.")]
+ public sealed class GaussianProcessDiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolution {
+ [StorableConstructor]
+ private GaussianProcessDiscriminantFunctionClassificationSolution(bool deserializing)
+ : base(deserializing) {
+ }
+
+ private GaussianProcessDiscriminantFunctionClassificationSolution(
+ GaussianProcessDiscriminantFunctionClassificationSolution original, Cloner cloner)
+ : base(original, cloner) {
+ }
+
+ public GaussianProcessDiscriminantFunctionClassificationSolution(GaussianProcessDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
+ : base(model, problemData) {
+ RecalculateResults();
+ }
+
+ // Clones this solution through the private cloning constructor.
+ public override IDeepCloneable Clone(Cloner cloner) {
+ return new GaussianProcessDiscriminantFunctionClassificationSolution(this, cloner);
+ }
+ // Recomputes the results by calling CalculateResults and then CalculateRegressionResults; also invoked from the public constructor.
+ protected override void RecalculateResults() {
+ CalculateResults();
+ CalculateRegressionResults();
+ }
+ }
+}
Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessModel.cs
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessModel.cs (revision 8622)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessModel.cs (revision 8623)
@@ -228,4 +228,5 @@
#endregion
+
private IEnumerable GetEstimatedValuesHelper(Dataset dataset, IEnumerable rows) {
var newX = AlglibUtil.PrepareAndScaleInputMatrix(dataset, allowedInputVariables, rows, inputScaling);
Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj (revision 8622)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj (revision 8623)
@@ -120,4 +120,9 @@
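+    <Compile Include="GaussianProcess\GaussianProcessClassification.cs" />
+    <Compile Include="GaussianProcess\GaussianProcessClassificationModelCreator.cs" />
+    <Compile Include="GaussianProcess\GaussianProcessClassificationSolutionCreator.cs" />
+    <Compile Include="GaussianProcess\GaussianProcessDiscriminantFunctionClassificationModel.cs" />
+    <Compile Include="GaussianProcess\GaussianProcessDiscriminantFunctionClassificationSolution.cs" />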
Index: /trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/Interfaces/ISymbolicDiscriminantFunctionClassificationModel.cs
===================================================================
--- /trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/Interfaces/ISymbolicDiscriminantFunctionClassificationModel.cs (revision 8622)
+++ /trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/Interfaces/ISymbolicDiscriminantFunctionClassificationModel.cs (revision 8623)
@@ -22,5 +22,5 @@
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
public interface ISymbolicDiscriminantFunctionClassificationModel : IDiscriminantFunctionClassificationModel, ISymbolicClassificationModel {
- IDiscriminantFunctionThresholdCalculator ThresholdCalculator { get; }
+
}
}
Index: /trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationModel.cs
===================================================================
--- /trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationModel.cs (revision 8622)
+++ /trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationModel.cs (revision 8623)
@@ -32,5 +32,6 @@
[StorableClass]
[Item(Name = "SymbolicClassificationModel", Description = "Represents a symbolic classification model.")]
- public abstract class SymbolicClassificationModel : SymbolicDataAnalysisModel, ISymbolicClassificationModel {
+ public abstract class
+ SymbolicClassificationModel : SymbolicDataAnalysisModel, ISymbolicClassificationModel {
[Storable]
private double lowerEstimationLimit;
Index: /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs
===================================================================
--- /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs (revision 8622)
+++ /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs (revision 8623)
@@ -51,4 +51,11 @@
}
+ private IDiscriminantFunctionThresholdCalculator thresholdCalculator; // set by the public constructor; RecalculateModelParameters uses it to compute class values and thresholds
+ [Storable]
+ public IDiscriminantFunctionThresholdCalculator ThresholdCalculator {
+ get { return thresholdCalculator; }
+ private set { thresholdCalculator = value; }
+ }
+
[StorableConstructor]
@@ -61,11 +68,17 @@
}
- public DiscriminantFunctionClassificationModel(IRegressionModel model)
+ public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator)
: base() {
this.name = ItemName;
this.description = ItemDescription;
this.model = model;
- this.classValues = new double[] { 0.0 };
- this.thresholds = new double[] { double.NegativeInfinity };
+ this.classValues = new double[0];
+ this.thresholds = new double[0];
+ this.thresholdCalculator = thresholdCalculator;
+ }
+
+ [StorableHook(HookType.AfterDeserialization)]
+ private void AfterDeserialization() {
+ if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator();
}
@@ -80,4 +93,14 @@
}
+ // Recomputes class values and thresholds with the threshold calculator from the target values and the estimated values of the given rows, then stores them in the model.
+ public virtual void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
+ double[] classValues;
+ double[] thresholds;
+ var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
+ var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows);
+ thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);
+ SetThresholdsAndClassValues(thresholds, classValues);
+ }
+
public IEnumerable GetEstimatedValues(Dataset dataset, IEnumerable rows) {
return model.GetEstimatedValues(dataset, rows);
@@ -85,4 +108,5 @@
public IEnumerable GetEstimatedClassValues(Dataset dataset, IEnumerable rows) {
+ if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model.");
foreach (var x in GetEstimatedValues(dataset, rows)) {
int classIndex = 0;
Index: /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs
===================================================================
--- /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs (revision 8622)
+++ /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs (revision 8623)
@@ -107,5 +107,5 @@
double maxDensityClassValue = -1;
foreach (var classValue in originalClasses) {
- double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]);
+ double density = LogNormalDensity(m, classMean[classValue], classStdDev[classValue]);
if (density > maxDensity) {
maxDensity = density;
@@ -139,10 +139,10 @@
}
- private static double NormalDensity(double x, double mu, double sigma) {
- if (sigma.IsAlmost(0.0)) {
- if (x.IsAlmost(mu)) return 1.0; else return 0.0;
- } else {
- return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma));
- }
+ // Natural logarithm of the normal density with mean mu and standard deviation sigma, evaluated at x.
+ private static double LogNormalDensity(double x, double mu, double sigma) {
+ // Degenerate case sigma ~ 0: return log(1) = 0 at x ~ mu and log(0) = -infinity elsewhere, the log-space
+ // counterparts of the 1.0 / 0.0 returned by the former NormalDensity. Without this guard sigma == 0 yields NaN.
+ if (sigma.IsAlmost(0.0)) return x.IsAlmost(mu) ? 0.0 : double.NegativeInfinity;
+ return -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - ((x - mu) * (x - mu)) / (2.0 * sigma * sigma);
}
Index: /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs
===================================================================
--- /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs (revision 8622)
+++ /trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs (revision 8623)
@@ -26,4 +26,6 @@
IEnumerable Thresholds { get; }
IEnumerable ClassValues { get; }
+ IDiscriminantFunctionThresholdCalculator ThresholdCalculator { get; }
+ void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows); // recalculates thresholds and class values from the given rows of the problem data
// class values and thresholds can only be assigned simultanously
void SetThresholdsAndClassValues(IEnumerable thresholds, IEnumerable classValues);