Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.VRPEnhancements/HeuristicLab.Problems.VehicleRouting/3.4/Encodings/Potvin/Creators/Cluster.cs @ 14447

Last change on this file since 14447 was 14447, checked in by jzenisek, 7 years ago

#2707 updated generic type instantiation in cluster algorithm

File size: 7.7 KB
Line 
1using System;
2using System.Collections;
3using System.Collections.Generic;
4using HeuristicLab.Core;
5using HeuristicLab.Problems.VehicleRouting.Interfaces;
6
7namespace HeuristicLab.Problems.VehicleRouting.Encodings.Potvin {
8
9  public class ClusterAlgorithm<TCluster,TClusterElement>
10    where TCluster : Cluster<TClusterElement>, new()
11    where TClusterElement : ClusterElement, new() {
12
13    public static List<TCluster> KMeans(IRandom random, List<TClusterElement> clusterElements,
14      int k, double changeThreshold) {
15      HashSet<int> initMeans = new HashSet<int>();
16      int nextMean = -1;
17      List<TCluster> clusters = CreateCList();
18
19      // (1) initialize each cluster with a random element as mean
20      for (int i = 0; i < k && i < clusterElements.Count; i++) {
21        TCluster cluster = new TCluster();
22        cluster.Id = i;
23
24        do {
25          nextMean = random.Next(0, clusterElements.Count);
26        } while (initMeans.Contains(nextMean));
27        initMeans.Add(nextMean);
28        cluster.SetMean(clusterElements[nextMean]);
29        clusters.Add(cluster);
30      }
31
32      // (2) repeat clustering until change rate is below threshold
33      double changeRate = 0.0;
34      do {
35        int changes = KMeansRun(clusters, clusterElements);
36        changeRate = (double)changes / clusterElements.Count;
37      } while (changeRate > changeThreshold);
38
39
40      // remove empty clusters
41      clusters.RemoveAll(c => c.Elements.Count.Equals(0));
42      return clusters;
43    }
44
45    private static int KMeansRun(List<TCluster> clusters, List<TClusterElement> clusterElements) {
46      int changes = 0;
47
48      // clear clusters from previous assigned elements
49      foreach (var c in clusters) {
50        c.Elements.Clear();
51      }
52
53      // assign elements to currently most suitable clusters
54      foreach (var e in clusterElements) {
55        int optClusterIdx = 0;
56        double optImpact = clusters[optClusterIdx].CalculateImpact(e);
57        for (int i = 1; i < clusters.Count; i++) {
58          double impact = clusters[i].CalculateImpact(e);
59          if (impact < optImpact) {
60            optImpact = impact;
61            optClusterIdx = i;
62          }
63        }
64        if (clusters[optClusterIdx].AddElement(e)) {
65          changes++;
66        }
67      }
68
69      // update mean and variance
70      foreach (var c in clusters) {
71        c.CalculateMean();
72        c.CalculateVariance();
73      }
74
75      return changes;
76    }
77
78    private static List<TCluster> CreateCList() {
79      var listType = typeof(List<>);
80      var constructedListType = listType.MakeGenericType(typeof(TCluster));
81      return (List<TCluster>)Activator.CreateInstance(constructedListType);
82    }
83  }
84
85  #region Cluster
86  public interface ICluster<T> where T : ClusterElement {
87    void SetMean(T o);
88    bool AddElement(T o);
89    void CalculateMean();
90    void CalculateVariance();
91    double CalculateImpact(T e);
92    double CalculateDistance(T e1, T e2);
93  }
94  public abstract class Cluster<T> : ICluster<T> where T : ClusterElement, new() {
95    public int Id;
96    public List<T> Elements { get; private set; }
97    public T Mean { get; set; }
98    public double Variance { get; set; }
99    protected Cluster() {
100      Elements = new List<T>();
101    }
102    protected Cluster(int id) {
103      Id = id;
104      Elements = new List<T>();
105    }
106    public bool AddElement(T e) {
107      Elements.Add(e);
108
109      bool clusterChanged = e.ClusterId != Id;
110      e.ClusterId = Id;
111      return clusterChanged;
112    }
113    public void SetMean(T e) {
114      Mean = e;
115      Mean.ClusterId = Id;
116    }
117    public abstract void CalculateMean();
118    public abstract double CalculateDistance(T e1, T e2);
119    public virtual void CalculateVariance() {
120      if (Mean == null)
121        CalculateMean();
122
123      Variance = 0.0;
124      foreach (T e in Elements) {
125        Variance += Math.Pow(CalculateDistance(Mean, e), 2);
126      }
127      Variance /= Elements.Count;
128    }
129    public virtual double CalculateImpact(T e) {
130      if (Mean == null)
131        CalculateMean();
132
133      double newVariance = (Variance * Elements.Count + Math.Pow(CalculateDistance(Mean, e), 2)) / (Elements.Count + 1);
134      return newVariance - Variance;
135    }
136  }
137  public class SpatialDistanceCluster : Cluster<SpatialDistanceClusterElement> {
138    public SpatialDistanceCluster() : base() {
139    }
140
141    public SpatialDistanceCluster(int id) : base(id) {
142    }
143    public override void CalculateMean() {
144      int dimensions = (Mean != null) ? Mean.Coordinates.Length : (Elements.Count > 0) ? Elements[0].Coordinates.Length : 0;
145      SpatialDistanceClusterElement mean = new SpatialDistanceClusterElement(-1, dimensions, Id);
146      foreach (SpatialDistanceClusterElement e in Elements) {
147        for (int i = 0; i < Mean.Coordinates.Length; i++) {
148          mean.Coordinates[i] += (e.Coordinates[i] / Elements.Count);
149        }
150      }
151      Mean = mean;
152    }
153    public override double CalculateDistance(SpatialDistanceClusterElement e1, SpatialDistanceClusterElement e2) {
154      if (!e1.Coordinates.Length.Equals(e2.Coordinates.Length)) {
155        throw new ArgumentException("Distance could not be calculated since number of dimensions is unequal.");
156      }
157
158      double distance = 0.0;
159      for (int i = 0; i < e1.Coordinates.Length; i++) {
160        distance += Math.Pow(e1.Coordinates[i] - e2.Coordinates[i], 2);
161      }
162      return Math.Sqrt(distance);
163    }
164  }
165  public class TemporalDistanceCluster : Cluster<TemporalDistanceClusterElement> {
166    public TemporalDistanceCluster() : base() {
167    }
168    public TemporalDistanceCluster(int id) : base(id) {
169    }
170    public override void CalculateMean() {
171      TemporalDistanceClusterElement mean = new TemporalDistanceClusterElement(-1, Id);
172      foreach (TemporalDistanceClusterElement e in Elements) {
173        mean.ReadyTime += e.ReadyTime/Elements.Count;
174        mean.DueTime += e.DueTime/Elements.Count;
175      }
176      Mean = mean;
177    }
178    public override double CalculateDistance(TemporalDistanceClusterElement e1, TemporalDistanceClusterElement e2) {
179      double distance = 0.0;
180
181      distance += Math.Pow(e1.ReadyTime - e2.ReadyTime, 2);
182      distance += Math.Pow(e1.DueTime - e2.DueTime, 2);
183
184      return Math.Sqrt(distance);
185    }
186  }
187  #endregion Cluster
188
189  #region ClusterElement
190  public abstract class ClusterElement {
191    public int Id { get; set; }
192    public int ClusterId { get; set; }
193    protected ClusterElement() {
194      ClusterId = -1;
195    }
196    public ClusterElement(int id, int clusterId = -1) {
197      Id = id;
198      ClusterId = clusterId;
199    }
200  }
201  public class SpatialDistanceClusterElement : ClusterElement {
202    public double[] Coordinates { get; set; }
203    public SpatialDistanceClusterElement() : base() {
204    }
205    public SpatialDistanceClusterElement(int id, int clusterId = -1) : base(id, clusterId) {
206    }
207    public SpatialDistanceClusterElement(int id, double[] coordinates, int clusterId = -1) : base(id, clusterId) {
208      Coordinates = coordinates;
209    }
210    public SpatialDistanceClusterElement(int id, int dimensions, int clusterId = -1) : base(id, clusterId) {
211      Coordinates = new double[dimensions];
212    }
213  }
214  public class TemporalDistanceClusterElement : ClusterElement {
215    public double ReadyTime { get; set; }
216    public double DueTime { get; set; }
217    public TemporalDistanceClusterElement() : base() {
218    }
219    public TemporalDistanceClusterElement(int id, int clusterId = -1) : base(id, clusterId) {
220    }
221    public TemporalDistanceClusterElement(int id, double readyTime, double dueTime, int clusterId = -1) : base(id, clusterId) {
222      ReadyTime = readyTime;
223      DueTime = dueTime;
224    }
225  }
226
227  #endregion ClusterElement
228}
Note: See TracBrowser for help on using the repository browser.