source: branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4/Nca/NcaAlgorithm.cs @ 16057

Last change on this file since 16057 was 16057, checked in by jkarder, 2 years ago

#2839:

File size: 13.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Algorithms.GradientDescent;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Operators;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.PluginInfrastructure;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Random;
35
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Neighborhood Components Analysis (NCA), exposed as an engine algorithm that
  /// operates on an <see cref="IClassificationProblem"/>.
  /// </summary>
  /// <remarks>
  /// The public constructor registers all parameters, offers every discovered
  /// <see cref="INcaInitializer"/> as a valid initialization method and wires the
  /// operator graph: a <c>RandomCreator</c> is followed by the selected initializer and an
  /// <c>LbfgsInitializer</c>; afterwards an <c>LbfgsMakeStep</c> feeds a <c>ConditionalBranch</c>
  /// on the step's termination criterion. The false branch computes gradients, creates a
  /// model, updates and analyzes the results and returns to the step operator; the true
  /// branch creates the final model, analyzes it and creates the solution.
  /// </remarks>
  [Item("Neighborhood Components Analysis (NCA)", @"Implementation of Neighborhood Components Analysis
based on the description of J. Goldberger, S. Roweis, G. Hinton, R. Salakhutdinov. 2005.
Neighbourhood Component Analysis. Advances in Neural Information Processing Systems, 17. pp. 513-520
with additional regularizations described in Z. Yang, J. Laaksonen. 2007.
Regularized Neighborhood Component Analysis. Lecture Notes in Computer Science, 4522. pp. 253-262.")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 170)]
  [StorableClass]
  public sealed class NcaAlgorithm : EngineAlgorithm {
    #region Parameter Names
    // Keys under which the parameters are stored in the Parameters collection and
    // under which the operators of the graph look up their values.
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string KParameterName = "K";
    private const string DimensionsParameterName = "Dimensions";
    private const string InitializationParameterName = "Initialization";
    private const string NeighborSamplesParameterName = "NeighborSamples";
    private const string IterationsParameterName = "Iterations";
    private const string RegularizationParameterName = "Regularization";
    private const string NcaModelCreatorParameterName = "NcaModelCreator";
    private const string NcaSolutionCreatorParameterName = "NcaSolutionCreator";
    private const string ApproximateGradientsParameterName = "ApproximateGradients";
    private const string NcaMatrixParameterName = "NcaMatrix";
    private const string NcaMatrixGradientsParameterName = "NcaMatrixGradients";
    private const string QualityParameterName = "Quality";
    #endregion

    /// <summary>The problem type this algorithm accepts: <see cref="IClassificationProblem"/>.</summary>
    public override Type ProblemType {
      get {
        return typeof(IClassificationProblem);
      }
    }

    /// <summary>The base class' problem, viewed as an <see cref="IClassificationProblem"/>.</summary>
    public new IClassificationProblem Problem {
      get {
        return (IClassificationProblem)base.Problem;
      }
      set {
        base.Problem = value;
      }
    }

    #region Parameter Properties
    /// <summary>The parameter registered under the name "Seed".</summary>
    public IValueParameter<IntValue> SeedParameter {
      get {
        return (IValueParameter<IntValue>)Parameters[SeedParameterName];
      }
    }

    /// <summary>The parameter registered under the name "SetSeedRandomly".</summary>
    public IValueParameter<BoolValue> SetSeedRandomlyParameter {
      get {
        return (IValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName];
      }
    }

    /// <summary>The parameter registered under the name "K".</summary>
    public IFixedValueParameter<IntValue> KParameter {
      get {
        return (IFixedValueParameter<IntValue>)Parameters[KParameterName];
      }
    }

    /// <summary>The parameter registered under the name "Dimensions".</summary>
    public IFixedValueParameter<IntValue> DimensionsParameter {
      get {
        return (IFixedValueParameter<IntValue>)Parameters[DimensionsParameterName];
      }
    }

    /// <summary>The parameter registered under the name "Initialization".</summary>
    public IConstrainedValueParameter<INcaInitializer> InitializationParameter {
      get {
        return (IConstrainedValueParameter<INcaInitializer>)Parameters[InitializationParameterName];
      }
    }

    /// <summary>The parameter registered under the name "NeighborSamples".</summary>
    public IFixedValueParameter<IntValue> NeighborSamplesParameter {
      get {
        return (IFixedValueParameter<IntValue>)Parameters[NeighborSamplesParameterName];
      }
    }

    /// <summary>The parameter registered under the name "Iterations".</summary>
    public IFixedValueParameter<IntValue> IterationsParameter {
      get {
        return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName];
      }
    }

    /// <summary>The parameter registered under the name "Regularization".</summary>
    public IFixedValueParameter<DoubleValue> RegularizationParameter {
      get {
        return (IFixedValueParameter<DoubleValue>)Parameters[RegularizationParameterName];
      }
    }

    /// <summary>The parameter registered under the name "ApproximateGradients".</summary>
    public IValueParameter<BoolValue> ApproximateGradientsParameter {
      get {
        return (IValueParameter<BoolValue>)Parameters[ApproximateGradientsParameterName];
      }
    }

    /// <summary>The parameter registered under the name "NcaModelCreator".</summary>
    public IValueParameter<INcaModelCreator> NcaModelCreatorParameter {
      get {
        return (IValueParameter<INcaModelCreator>)Parameters[NcaModelCreatorParameterName];
      }
    }

    /// <summary>The parameter registered under the name "NcaSolutionCreator".</summary>
    public IValueParameter<INcaSolutionCreator> NcaSolutionCreatorParameter {
      get {
        return (IValueParameter<INcaSolutionCreator>)Parameters[NcaSolutionCreatorParameterName];
      }
    }
    #endregion

    #region Properties
    /// <summary>The seed of the random number generator.</summary>
    public int Seed {
      get {
        return SeedParameter.Value.Value;
      }
      set {
        SeedParameter.Value.Value = value;
      }
    }

    /// <summary>Whether the seed should be randomly reset each time the algorithm is run.</summary>
    public bool SetSeedRandomly {
      get {
        return SetSeedRandomlyParameter.Value.Value;
      }
      set {
        SetSeedRandomlyParameter.Value.Value = value;
      }
    }

    /// <summary>The K for the nearest neighbor.</summary>
    public int K {
      get {
        return KParameter.Value.Value;
      }
      set {
        KParameter.Value.Value = value;
      }
    }

    /// <summary>The number of dimensions that NCA should reduce the data to.</summary>
    public int Dimensions {
      get {
        return DimensionsParameter.Value.Value;
      }
      set {
        DimensionsParameter.Value.Value = value;
      }
    }

    /// <summary>How many of the neighbors should be sampled.</summary>
    public int NeighborSamples {
      get {
        return NeighborSamplesParameter.Value.Value;
      }
      set {
        NeighborSamplesParameter.Value.Value = value;
      }
    }

    /// <summary>How many iterations the optimization is allowed to perform.</summary>
    public int Iterations {
      get {
        return IterationsParameter.Value.Value;
      }
      set {
        IterationsParameter.Value.Value = value;
      }
    }

    /// <summary>The regularization value; the parameter description calls for a non-negative number.</summary>
    public double Regularization {
      get {
        return RegularizationParameter.Value.Value;
      }
      set {
        RegularizationParameter.Value.Value = value;
      }
    }

    /// <summary>The operator that creates an NCA model out of the matrix.</summary>
    public INcaModelCreator NcaModelCreator {
      get {
        return NcaModelCreatorParameter.Value;
      }
      set {
        NcaModelCreatorParameter.Value = value;
      }
    }

    /// <summary>The operator that creates an NCA solution given a model and some data.</summary>
    public INcaSolutionCreator NcaSolutionCreator {
      get {
        return NcaSolutionCreatorParameter.Value;
      }
      set {
        NcaSolutionCreatorParameter.Value = value;
      }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework; forwards to the base class.</summary>
    [StorableConstructor]
    private NcaAlgorithm(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>Cloning constructor; all copying is delegated to the base class.</summary>
    private NcaAlgorithm(NcaAlgorithm original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates a fully configured algorithm: parameters, valid initializers, operator graph
    /// and a fresh <see cref="ClassificationProblem"/> as the problem.
    /// </summary>
    public NcaAlgorithm()
      : base() {
      RegisterParameters();
      RegisterInitializers();
      BuildOperatorGraph();

      Problem = new ClassificationProblem();
    }

    /// <summary>
    /// Adds all parameters with their default values (in display order) and hides the
    /// solution creator and the gradient approximation flag.
    /// </summary>
    private void RegisterParameters() {
      // NOTE(review): the description texts below contain typos ("paramter", "local optima");
      // they are runtime strings and are intentionally left untouched here.
      Parameters.Add(new ValueParameter<IntValue>(
        SeedParameterName,
        "The seed of the random number generator.",
        new IntValue(0)));
      Parameters.Add(new ValueParameter<BoolValue>(
        SetSeedRandomlyParameterName,
        "A boolean flag that indicates whether the seed should be randomly reset each time the algorithm is run.",
        new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<IntValue>(
        KParameterName,
        "The K for the nearest neighbor.",
        new IntValue(3)));
      Parameters.Add(new FixedValueParameter<IntValue>(
        DimensionsParameterName,
        "The number of dimensions that NCA should reduce the data to.",
        new IntValue(2)));
      Parameters.Add(new ConstrainedValueParameter<INcaInitializer>(
        InitializationParameterName,
        "Which method should be used to initialize the matrix. Typically LDA (linear discriminant analysis) should provide a good estimate."));
      Parameters.Add(new FixedValueParameter<IntValue>(
        NeighborSamplesParameterName,
        "How many of the neighbors should be sampled in order to speed up the calculation. This should be at least the value of k and at most the number of training instances minus one will be used.",
        new IntValue(60)));
      Parameters.Add(new FixedValueParameter<IntValue>(
        IterationsParameterName,
        "How many iterations the conjugate gradient (CG) method should be allowed to perform. The method might still terminate earlier if a local optima has already been reached.",
        new IntValue(50)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(
        RegularizationParameterName,
        "A non-negative paramter which can be set to increase generalization and avoid overfitting. If set to 0 the algorithm is similar to NCA as proposed by Goldberger et al.",
        new DoubleValue(0)));
      Parameters.Add(new ValueParameter<INcaModelCreator>(
        NcaModelCreatorParameterName,
        "Creates an NCA model out of the matrix.",
        new NcaModelCreator()));
      Parameters.Add(new ValueParameter<INcaSolutionCreator>(
        NcaSolutionCreatorParameterName,
        "Creates an NCA solution given a model and some data.",
        new NcaSolutionCreator()));
      Parameters.Add(new ValueParameter<BoolValue>(
        ApproximateGradientsParameterName,
        "True if the gradient should be approximated otherwise they are computed exactly.",
        new BoolValue()));

      NcaSolutionCreatorParameter.Hidden = true;
      ApproximateGradientsParameter.Hidden = true;
    }

    /// <summary>
    /// Offers every <see cref="INcaInitializer"/> instance provided by the application manager
    /// (ordered by item name) as a valid value and selects an <see cref="LdaInitializer"/>
    /// as the value when one is among them.
    /// </summary>
    private void RegisterInitializers() {
      IConstrainedValueParameter<INcaInitializer> initialization = InitializationParameter;
      var discovered = ApplicationManager.Manager.GetInstances<INcaInitializer>().OrderBy(x => x.ItemName);

      INcaInitializer ldaCandidate = null;
      foreach (INcaInitializer candidate in discovered) {
        initialization.ValidValues.Add(candidate);
        // The last LDA initializer encountered wins, as there is no early exit.
        if (candidate is LdaInitializer) {
          ldaCandidate = candidate;
        }
      }

      if (ldaCandidate == null) return;
      initialization.Value = ldaCandidate;
    }

    /// <summary>
    /// Instantiates the operators and connects them into the graph executed by the engine.
    /// </summary>
    private void BuildOperatorGraph() {
      RandomCreator rng = new RandomCreator();
      Placeholder initializerSlot = new Placeholder();
      LbfgsInitializer lbfgsInit = new LbfgsInitializer();
      LbfgsMakeStep lbfgsStep = new LbfgsMakeStep();
      ConditionalBranch terminationBranch = new ConditionalBranch();
      NcaGradientCalculator gradients = new NcaGradientCalculator();
      Placeholder loopModelSlot = new Placeholder();
      LbfgsUpdateResults resultUpdater = new LbfgsUpdateResults();
      LbfgsAnalyzer loopAnalyzer = new LbfgsAnalyzer();
      Placeholder finalModelSlot = new Placeholder();
      LbfgsAnalyzer closingAnalyzer = new LbfgsAnalyzer();
      Placeholder solutionSlot = new Placeholder();

      // Entry point: the random creator takes seed settings from the algorithm's
      // parameters (looked up by name) instead of holding local values.
      OperatorGraph.InitialOperator = rng;
      rng.SeedParameter.ActualName = SeedParameterName;
      rng.SeedParameter.Value = null;
      rng.SetSeedRandomlyParameter.ActualName = SetSeedRandomlyParameterName;
      rng.SetSeedRandomlyParameter.Value = null;
      rng.Successor = initializerSlot;

      // The placeholder resolves to whatever operator the "Initialization" parameter holds.
      initializerSlot.Name = "(NcaInitializer)";
      initializerSlot.OperatorParameter.ActualName = InitializationParameterName;
      initializerSlot.Successor = lbfgsInit;

      lbfgsInit.IterationsParameter.ActualName = IterationsParameterName;
      lbfgsInit.PointParameter.ActualName = NcaMatrixParameterName;
      lbfgsInit.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      lbfgsInit.Successor = lbfgsStep;

      // All L-BFGS operators of the loop share the state published by the initializer.
      string lbfgsStateName = lbfgsInit.StateParameter.Name;

      lbfgsStep.StateParameter.ActualName = lbfgsStateName;
      lbfgsStep.PointParameter.ActualName = NcaMatrixParameterName;
      lbfgsStep.Successor = terminationBranch;

      // Branch on the step's termination criterion: false continues the loop,
      // true leaves it towards the final model and solution creation.
      terminationBranch.ConditionParameter.ActualName = lbfgsStep.TerminationCriterionParameter.Name;
      terminationBranch.FalseBranch = gradients;
      terminationBranch.TrueBranch = finalModelSlot;

      gradients.Successor = loopModelSlot;

      loopModelSlot.OperatorParameter.ActualName = NcaModelCreatorParameterName;
      loopModelSlot.Successor = resultUpdater;

      resultUpdater.StateParameter.ActualName = lbfgsStateName;
      resultUpdater.QualityParameter.ActualName = QualityParameterName;
      resultUpdater.QualityGradientsParameter.ActualName = NcaMatrixGradientsParameterName;
      resultUpdater.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      resultUpdater.Successor = loopAnalyzer;

      loopAnalyzer.QualityParameter.ActualName = QualityParameterName;
      loopAnalyzer.PointParameter.ActualName = NcaMatrixParameterName;
      loopAnalyzer.QualityGradientsParameter.ActualName = NcaMatrixGradientsParameterName;
      loopAnalyzer.StateParameter.ActualName = lbfgsStateName;
      loopAnalyzer.PointsTableParameter.ActualName = "Matrix table";
      loopAnalyzer.QualityGradientsTableParameter.ActualName = "Gradients table";
      loopAnalyzer.QualitiesTableParameter.ActualName = "Qualities";
      // Close the loop.
      loopAnalyzer.Successor = lbfgsStep;

      finalModelSlot.OperatorParameter.ActualName = NcaModelCreatorParameterName;
      finalModelSlot.Successor = closingAnalyzer;

      // The closing analyzer uses the same table names as the in-loop analyzer.
      closingAnalyzer.QualityParameter.ActualName = QualityParameterName;
      closingAnalyzer.PointParameter.ActualName = NcaMatrixParameterName;
      closingAnalyzer.QualityGradientsParameter.ActualName = NcaMatrixGradientsParameterName;
      closingAnalyzer.PointsTableParameter.ActualName = loopAnalyzer.PointsTableParameter.ActualName;
      closingAnalyzer.QualityGradientsTableParameter.ActualName = loopAnalyzer.QualityGradientsTableParameter.ActualName;
      closingAnalyzer.QualitiesTableParameter.ActualName = loopAnalyzer.QualitiesTableParameter.ActualName;
      closingAnalyzer.Successor = solutionSlot;

      solutionSlot.OperatorParameter.ActualName = NcaSolutionCreatorParameterName;
    }

    /// <summary>Creates a deep clone of this algorithm via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new NcaAlgorithm(this, cloner);
    }

    /// <summary>Prepares the algorithm; does nothing while no problem is assigned.</summary>
    public override void Prepare() {
      if (Problem == null) return;
      base.Prepare();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.