source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/WeightedEuclideanDistance.cs @ 15532

Last change on this file since 15532 was 15532, checked in by bwerth, 2 years ago

#2850 merged Weighted TSNE to trunk

File size: 5.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
34  [StorableClass]
35  [Item("WeightedEuclideanDistance", "A weighted norm function that uses Euclidean distance √(Σ(w[i]²*(p1[i]-p2[i])²))")]
36  public class WeightedEuclideanDistance : ParameterizedNamedItem, IDistance<IEnumerable<double>> {
37    [Storable]
38    private double[] weights;
39    public const string WeightsParameterName = "Weights";
40    public IValueParameter<DoubleArray> WeightsParameter {
41      get { return (IValueParameter<DoubleArray>) Parameters[WeightsParameterName]; }
42    }
43
44    public DoubleArray Weights {
45      get { return WeightsParameter.Value; }
46      set { WeightsParameter.Value = value; }
47    }
48
49    #region HLConstructors & Cloning
50    [StorableConstructor]
51    protected WeightedEuclideanDistance(bool deserializing) : base(deserializing) { }
52    private void AfterDeserialization() {
53      RegisterParameterEvents();
54    }
55    protected WeightedEuclideanDistance(WeightedEuclideanDistance original, Cloner cloner) : base(original, cloner) {
56      RegisterParameterEvents();
57      weights = original.weights != null ? original.weights.ToArray() : null;
58    }
59    public override IDeepCloneable Clone(Cloner cloner) {
60      return new WeightedEuclideanDistance(this, cloner);
61    }
62    public WeightedEuclideanDistance() {
63      Parameters.Add(new ValueParameter<DoubleArray>(WeightsParameterName, "The weights used to modify the euclidean distance.", new DoubleArray(new[] {1.0})));
64      RegisterParameterEvents();
65    }
66    #endregion
67
68    public static double GetDistance(IEnumerable<double> point1, IEnumerable<double> point2, IEnumerable<double> weights) {
69      using (IEnumerator<double> p1Enum = point1.GetEnumerator(), p2Enum = point2.GetEnumerator(), weEnum = weights.GetEnumerator()) {
70        var sum = 0.0;
71        while (p1Enum.MoveNext() & p2Enum.MoveNext() & weEnum.MoveNext()) {
72          var d = p1Enum.Current - p2Enum.Current;
73          var w = weEnum.Current;
74          sum += d * d * w * w;
75        }
76        if (weEnum.MoveNext() || p1Enum.MoveNext() || p2Enum.MoveNext()) throw new ArgumentException("Weighted Euclidean distance not defined on vectors of different length");
77        return Math.Sqrt(sum);
78      }
79    }
80
81    public double Get(IEnumerable<double> a, IEnumerable<double> b) {
82      return GetDistance(a, b, weights);
83    }
84    public IComparer<IEnumerable<double>> GetDistanceComparer(IEnumerable<double> item) {
85      return new DistanceBase<IEnumerable<double>>.DistanceComparer(item, this);
86    }
87    public double Get(object x, object y) {
88      return Get((IEnumerable<double>) x, (IEnumerable<double>) y);
89    }
90    public IComparer GetDistanceComparer(object item) {
91      return new DistanceBase<IEnumerable<double>>.DistanceComparer((IEnumerable<double>) item, this);
92    }
93
94    public void AdaptToProblemData(IDataAnalysisProblemData problemData) {
95      Weights = new DoubleArray(problemData.AllowedInputVariables.Select(v => Weights.ElementNames.Contains(v) ? GetWeight(v) : 1).ToArray())
96        {ElementNames = problemData.AllowedInputVariables};
97    }
98    public void Initialize(IDataAnalysisProblemData problemData) {
99      if (Weights.Length != problemData.AllowedInputVariables.Count()) throw new ArgumentException("Number of Weights does not match the number of input variables");
100      weights = Weights.ElementNames.All(v => v == null || v.Equals(string.Empty)) ? 
101        Weights.ToArray() : 
102        problemData.AllowedInputVariables.Select(GetWeight).ToArray();
103    }
104    private double GetWeight(string v) {
105      var w = Weights;
106      var names = w.ElementNames.ToArray();
107      for (var i = 0; i < w.Length; i++) if (names[i].Equals(v)) return w[i];
108      throw new ArgumentException("weigth for " + v + " was requested but not specified.");
109    }
110    private void RegisterParameterEvents() {
111      WeightsParameter.ValueChanged += OnWeightsArrayChanged;
112      WeightsParameter.Value.ItemChanged += OnWeightChanged;
113    }
114    private void OnWeightChanged(object sender, EventArgs<int> e) {
115      WeightsParameter.Value.ItemChanged -= OnWeightChanged;
116      Weights[e.Value] = Math.Max(0, Weights[e.Value]);
117      WeightsParameter.Value.ItemChanged -= OnWeightChanged;
118    }
119    private void OnWeightsArrayChanged(object sender, EventArgs e) {
120      for (var i = 0; i < Weights.Length; i++)
121        Weights[i] = Math.Max(0, Weights[i]);
122      WeightsParameter.Value.ItemChanged += OnWeightChanged;
123    }
124  }
125}
Note: See TracBrowser for help on using the repository browser.