Free cookie consent management tool by TermsFeed Policy Generator

source: branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/VPTree.cs @ 14518

Last change on this file since 14518 was 14518, checked in by bwerth, 7 years ago

#2700 TSNEAnalysis is now a BasicAlg, hid some Parameters, added optional data normalization to make TSNE scaling-invariant

File size: 8.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20
21//Code is based on an implementation from Laurens van der Maaten
22
23/*
24*
25* Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
26* All rights reserved.
27*
28* Redistribution and use in source and binary forms, with or without
29* modification, are permitted provided that the following conditions are met:
30* 1. Redistributions of source code must retain the above copyright
31*    notice, this list of conditions and the following disclaimer.
32* 2. Redistributions in binary form must reproduce the above copyright
33*    notice, this list of conditions and the following disclaimer in the
34*    documentation and/or other materials provided with the distribution.
35* 3. All advertising materials mentioning features or use of this software
36*    must display the following acknowledgement:
37*    This product includes software developed by the Delft University of Technology.
38* 4. Neither the name of the Delft University of Technology nor the names of
39*    its contributors may be used to endorse or promote products derived from
40*    this software without specific prior written permission.
41*
42* THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
43* OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
44* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
45* EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
46* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
47* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
49* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
50* IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
51* OF SUCH DAMAGE.
52*
53*/
54#endregion
55
56using System;
57using System.Collections.Generic;
58using System.Linq;
59using HeuristicLab.Common;
60using HeuristicLab.Core;
61using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
62using HeuristicLab.Random;
63
64namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Vantage-point tree over a list of items, used for k-nearest-neighbour queries
/// with respect to an arbitrary <see cref="IDistance{T}"/>.
/// Call <see cref="Create"/> to build the tree, then <see cref="Search(T, int, out List{T}, out List{double})"/> to query it.
/// </summary>
[StorableClass]
public class VPTree<T> : DeepCloneable, IVPTree<T> where T : class, IDeepCloneable {
  #region properties
  // Items indexed by the tree nodes; reordered in place while the tree is built. Null until Create is called.
  [Storable]
  private List<T> items;
  // Current pruning radius of the running search; reset at the start of every Search call.
  [Storable]
  private double tau;
  // Root of the tree; null until Create is called or when the tree is empty.
  [Storable]
  private Node root;
  // Distance function used both for building and for searching the tree.
  [Storable]
  private IDistance<T> distance;
  #endregion

  #region HLConstructors & Cloning
  [StorableConstructor]
  protected VPTree(bool deserializing) { }
  /// <summary>
  /// Copy constructor used for deep cloning.
  /// </summary>
  protected VPTree(VPTree<T> original, Cloner cloner)
    : base(original, cloner) {
    // items is null as long as Create has not been called on the original
    if (original.items != null) items = original.items.Select(cloner.Clone).ToList();
    tau = original.tau;
    root = cloner.Clone(original.root);
    // clone the ORIGINAL's distance; this.distance is still unset at this point
    distance = cloner.Clone(original.distance);
  }
  public override IDeepCloneable Clone(Cloner cloner) { return new VPTree<T>(this, cloner); }
  /// <summary>
  /// Creates an empty tree that will use <paramref name="distance"/> for building and searching.
  /// </summary>
  public VPTree(IDistance<T> distance) {
    root = null;
    this.distance = distance;
  }
  #endregion

  /// <summary>
  /// (Re)builds the tree from the given items. The items are copied into an internal list,
  /// so the order of the caller's collection is not modified.
  /// </summary>
  /// <param name="items">The items to index.</param>
  public void Create(IEnumerable<T> items) {
    this.items = items.ToList();
    root = BuildFromPoints(0, this.items.Count);
  }

  /// <summary>
  /// Finds up to <paramref name="k"/> items with the smallest distance to <paramref name="target"/>.
  /// </summary>
  /// <param name="target">The query item.</param>
  /// <param name="k">Maximum number of neighbours to return; for k &lt;= 0 both result lists are empty.</param>
  /// <param name="results">The neighbours found, in the reverse of the order in which they leave the heap.</param>
  /// <param name="distances">The distance of each entry of <paramref name="results"/> to the target, in the same order.</param>
  public void Search(T target, int k, out List<T> results, out List<double> distances) {
    results = new List<T>();
    distances = new List<double>();
    // nothing requested: the recursive search would otherwise call RemoveMin on an empty heap
    if (k <= 0) return;

    IHeap<double, HeapItem> heap = new PriorityQueue<double, HeapItem>(double.MaxValue, double.MinValue, k);
    tau = double.MaxValue;
    Search(root, target, k, heap);

    // drain the heap, then reverse so that the entry removed last comes first
    while (heap.Size > 0) {
      var candidate = heap.PeekMinValue();
      results.Add(items[candidate.Index]);
      distances.Add(candidate.Dist);
      heap.RemoveMin();
    }
    results.Reverse();
    distances.Reverse();
  }

  /// <summary>
  /// Builds the subtree for the items in the index range [lower, upper).
  /// </summary>
  private Node BuildFromPoints(int lower, int upper) {
    // outer behaviour does not change with the random seed => no need to take the IRandom from the algorithm;
    // a single generator is shared by the whole build instead of allocating one per node
    return BuildFromPoints(lower, upper, new MersenneTwister());
  }

  /// <summary>
  /// Recursive worker of <see cref="BuildFromPoints(int, int)"/>; <paramref name="random"/> is used to pick the vantage points.
  /// </summary>
  private Node BuildFromPoints(int lower, int upper, IRandom random) {
    if (upper == lower)      // empty range, indicates that we're done here!
      return null;

    // Lower index is center of current node
    var node = new Node { index = lower };
    if (upper - lower <= 1) return node; // a single item is a leaf, nothing to partition

    // Choose an arbitrary point and move it to the start
    var i = (int)(random.NextDouble() * (upper - lower - 1)) + lower;
    items.Swap(lower, i);

    // Partition around the median distance
    var median = (upper + lower) / 2;
    items.NthElement(lower + 1, upper - 1, median, distance.GetDistanceComparer(items[lower]));

    // Threshold of the new node will be the distance to the median
    node.threshold = distance.Get(items[lower], items[median]);

    // Recursively build tree
    node.left = BuildFromPoints(lower + 1, median, random);
    node.right = BuildFromPoints(median, upper, random);

    // Return result
    return node;
  }

  /// <summary>
  /// Recursive worker of the public Search: visits <paramref name="node"/> and those of its
  /// subtrees that can still contain an item closer than the current radius tau.
  /// </summary>
  private void Search(Node node, T target, int k, IHeap<double, HeapItem> heap) {
    if (node == null) return;
    var dist = distance.Get(items[node.index], target);
    if (dist < tau) {
      // keys are negated distances, so the minimum-key entry is the candidate with the largest distance
      if (heap.Size == k) heap.RemoveMin();
      heap.Insert(-dist, new HeapItem(node.index, dist));//TODO check if minheap or maxheap should be used here
      if (heap.Size == k) tau = heap.PeekMinValue().Dist;
    }
    if (node.left == null && node.right == null) return;

    if (dist < node.threshold) {
      if (dist - tau <= node.threshold) Search(node.left, target, k, heap);   // if there can still be neighbors inside the ball, recursively search left child first
      if (dist + tau >= node.threshold) Search(node.right, target, k, heap);  // if there can still be neighbors outside the ball, recursively search right child
    } else {
      if (dist + tau >= node.threshold) Search(node.right, target, k, heap);  // if there can still be neighbors outside the ball, recursively search right child first
      if (dist - tau <= node.threshold) Search(node.left, target, k, heap);   // if there can still be neighbors inside the ball, recursively search left child
    }
  }

  /// <summary>
  /// Tree node: the index of its vantage point in the items list, the distance threshold
  /// separating the two subtrees, and the subtrees themselves (null where absent).
  /// </summary>
  [StorableClass]
  private sealed class Node : Item {
    [Storable]
    public int index;
    [Storable]
    public double threshold;
    [Storable]
    public Node left;
    [Storable]
    public Node right;

    #region HLConstructors & Cloning
    [StorableConstructor]
    private Node(bool deserializing) : base(deserializing) { }
    private Node(Node original, Cloner cloner) : base(original, cloner) {
      index = original.index;
      threshold = original.threshold;
      // children are null for leaves, so go through the cloner instead of dereferencing them
      left = cloner.Clone(original.left);
      right = cloner.Clone(original.right);
    }
    internal Node() {
      index = 0;
      threshold = 0;
      left = null;
      right = null;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new Node(this, cloner);
    }
    #endregion
  }

  /// <summary>
  /// Search candidate: index of an item in the items list together with its distance to the query target.
  /// </summary>
  [StorableClass]
  private sealed class HeapItem : Item, IComparable<HeapItem> {
    [Storable]
    public int Index;
    [Storable]
    public double Dist;

    #region HLConstructors & Cloning
    [StorableConstructor]
    private HeapItem(bool deserializing) : base(deserializing) { }
    private HeapItem(HeapItem original, Cloner cloner)
    : base(original, cloner) {
      Index = original.Index;
      Dist = original.Dist;
    }
    public override IDeepCloneable Clone(Cloner cloner) { return new HeapItem(this, cloner); }
    public HeapItem(int index, double dist) {
      Index = index;
      Dist = dist;
    }
    #endregion

    /// <summary>
    /// Orders candidates by their distance; any instance sorts after null.
    /// </summary>
    public int CompareTo(HeapItem other) {
      if (other == null) return 1;
      return Dist.CompareTo(other.Dist);
    }
  }
}
223}
Note: See TracBrowser for help on using the repository browser.