source: branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/WeightedEuclideanDistance.cs @ 16057

Last change on this file since 16057 was 16057, checked in by jkarder, 15 months ago

#2839:

File size: 5.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Distance between real-valued vectors in which every coordinate difference is scaled by a weight
  /// before the Euclidean norm is taken: √(Σ(w[i]²*(p1[i]-p2[i])²)).
  /// The user-editable <see cref="Weights"/> parameter is mapped onto the problem's allowed input variables
  /// by <see cref="Initialize"/>; <see cref="Get(IEnumerable{double}, IEnumerable{double})"/> then uses that
  /// cached, ordered weight vector. Negative entries written into <see cref="Weights"/> are clamped to 0.
  /// </summary>
  [StorableClass]
  [Item("WeightedEuclideanDistance", "A weighted norm function that uses Euclidean distance √(Σ(w[i]²*(p1[i]-p2[i])²))")]
  public class WeightedEuclideanDistance : ParameterizedNamedItem, IDistance<IEnumerable<double>> {
    // Weights ordered like the problem's allowed input variables; produced by Initialize and consumed by Get.
    [Storable]
    private double[] weights;

    // The DoubleArray whose ItemChanged event this instance is currently subscribed to.
    // Not persisted: it is rebuilt by RegisterParameterEvents after construction, cloning and deserialization.
    private DoubleArray observedWeights;

    public const string WeightsParameterName = "Weights";
    public IValueParameter<DoubleArray> WeightsParameter {
      get { return (IValueParameter<DoubleArray>)Parameters[WeightsParameterName]; }
    }

    /// <summary>
    /// The user-specified weights. Entries may be labelled with input-variable names via ElementNames;
    /// if no names are given, the weights are used positionally.
    /// </summary>
    public DoubleArray Weights {
      get { return WeightsParameter.Value; }
      set { WeightsParameter.Value = value; }
    }

    #region HLConstructors & Cloning
    [StorableConstructor]
    protected WeightedEuclideanDistance(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterParameterEvents();
    }
    protected WeightedEuclideanDistance(WeightedEuclideanDistance original, Cloner cloner) : base(original, cloner) {
      RegisterParameterEvents();
      weights = original.weights != null ? original.weights.ToArray() : null;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new WeightedEuclideanDistance(this, cloner);
    }
    public WeightedEuclideanDistance() {
      Parameters.Add(new ValueParameter<DoubleArray>(WeightsParameterName, "The weights used to modify the euclidean distance.", new DoubleArray(new[] { 1.0 })));
      RegisterParameterEvents();
    }
    #endregion

    /// <summary>
    /// Computes √(Σ(w[i]²*(p1[i]-p2[i])²)) over three sequences enumerated in lock-step.
    /// </summary>
    /// <param name="point1">First vector.</param>
    /// <param name="point2">Second vector.</param>
    /// <param name="weights">Per-coordinate weights.</param>
    /// <returns>The weighted Euclidean distance; 0 for three empty sequences.</returns>
    /// <exception cref="ArgumentNullException">If any argument is null.</exception>
    /// <exception cref="ArgumentException">If the three sequences do not all have the same length.</exception>
    public static double GetDistance(IEnumerable<double> point1, IEnumerable<double> point2, IEnumerable<double> weights) {
      if (point1 == null) throw new ArgumentNullException("point1");
      if (point2 == null) throw new ArgumentNullException("point2");
      if (weights == null) throw new ArgumentNullException("weights");
      using (IEnumerator<double> p1Enum = point1.GetEnumerator(), p2Enum = point2.GetEnumerator(), weEnum = weights.GetEnumerator()) {
        var sum = 0.0;
        while (true) {
          // Advance all three enumerators each round and keep the results. Re-probing them after the loop
          // (as a separate check) would skip one element and miss a length difference of exactly one.
          var hasP1 = p1Enum.MoveNext();
          var hasP2 = p2Enum.MoveNext();
          var hasWeight = weEnum.MoveNext();
          if (!hasP1 && !hasP2 && !hasWeight) break;
          if (!hasP1 || !hasP2 || !hasWeight)
            throw new ArgumentException("Weighted Euclidean distance not defined on vectors of different length");
          var d = p1Enum.Current - p2Enum.Current;
          var w = weEnum.Current;
          sum += d * d * w * w;
        }
        return Math.Sqrt(sum);
      }
    }

    /// <summary>
    /// Weighted Euclidean distance between <paramref name="a"/> and <paramref name="b"/> using the weights
    /// prepared by <see cref="Initialize"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">If <see cref="Initialize"/> has not been called yet.</exception>
    public double Get(IEnumerable<double> a, IEnumerable<double> b) {
      if (weights == null)
        throw new InvalidOperationException("The weights have not been initialized. Call Initialize before computing distances.");
      return GetDistance(a, b, weights);
    }
    /// <summary>Creates a comparer that orders items by their distance to <paramref name="item"/>.</summary>
    public IComparer<IEnumerable<double>> GetDistanceComparer(IEnumerable<double> item) {
      return new DistanceBase<IEnumerable<double>>.DistanceComparer(item, this);
    }
    /// <summary>Non-generic overload; both arguments are cast to <see cref="IEnumerable{Double}"/>.</summary>
    public double Get(object x, object y) {
      return Get((IEnumerable<double>)x, (IEnumerable<double>)y);
    }
    /// <summary>Non-generic overload; <paramref name="item"/> is cast to <see cref="IEnumerable{Double}"/>.</summary>
    public IComparer GetDistanceComparer(object item) {
      return new DistanceBase<IEnumerable<double>>.DistanceComparer((IEnumerable<double>)item, this);
    }

    /// <summary>
    /// Replaces <see cref="Weights"/> with one entry per allowed input variable of <paramref name="problemData"/>,
    /// labelled with the variable names. Existing weights are kept for variables they are already named for;
    /// all other variables get weight 1.
    /// </summary>
    public void AdaptToProblemData(IDataAnalysisProblemData problemData) {
      var inputs = problemData.AllowedInputVariables.ToArray();
      var knownNames = Weights.ElementNames.ToArray();
      var adapted = inputs.Select(v => knownNames.Contains(v) ? GetWeight(v) : 1).ToArray();
      Weights = new DoubleArray(adapted) { ElementNames = inputs };
    }

    /// <summary>
    /// Prepares the internal weight vector used by <see cref="Get(IEnumerable{double}, IEnumerable{double})"/>.
    /// Unnamed weights are taken positionally; named weights are looked up per allowed input variable.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// If the number of weights differs from the number of allowed input variables, or a named lookup fails.
    /// </exception>
    public void Initialize(IDataAnalysisProblemData problemData) {
      var inputs = problemData.AllowedInputVariables.ToArray();
      if (Weights.Length != inputs.Length) throw new ArgumentException("Number of Weights does not match the number of input variables");
      weights = Weights.ElementNames.All(string.IsNullOrEmpty)
        ? Weights.ToArray()
        : inputs.Select(GetWeight).ToArray();
    }

    // Returns the weight whose element name equals v (ordinal comparison).
    // ElementNames may contain null entries and may be shorter than the array, so both are guarded.
    private double GetWeight(string v) {
      var w = Weights;
      var names = w.ElementNames.ToArray();
      var count = Math.Min(w.Length, names.Length);
      for (var i = 0; i < count; i++) {
        if (string.Equals(names[i], v)) return w[i];
      }
      throw new ArgumentException("weight for " + v + " was requested but not specified.");
    }

    private void RegisterParameterEvents() {
      WeightsParameter.ValueChanged += OnWeightsArrayChanged;
      ObserveWeights(WeightsParameter.Value);
    }

    // Moves the ItemChanged subscription from the previously observed array (if any) to 'array' (may be null).
    private void ObserveWeights(DoubleArray array) {
      if (observedWeights != null) observedWeights.ItemChanged -= OnWeightChanged;
      observedWeights = array;
      if (observedWeights != null) observedWeights.ItemChanged += OnWeightChanged;
    }

    // Clamps a single edited weight to be non-negative.
    private void OnWeightChanged(object sender, EventArgs<int> e) {
      var array = observedWeights;
      if (array == null) return;
      // Detach while writing so the clamping assignment does not re-enter this handler;
      // re-attach in finally so the handler stays active for later edits.
      array.ItemChanged -= OnWeightChanged;
      try {
        array[e.Value] = Math.Max(0, array[e.Value]);
      } finally {
        array.ItemChanged += OnWeightChanged;
      }
    }

    // Runs when the Weights parameter gets a new array: stop observing the old array,
    // clamp all entries of the new one, then observe it.
    private void OnWeightsArrayChanged(object sender, EventArgs e) {
      ObserveWeights(null);
      var array = Weights;
      if (array != null) {
        for (var i = 0; i < array.Length; i++) {
          array[i] = Math.Max(0, array[i]);
        }
      }
      ObserveWeights(array);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.