Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp: 01/12/17 16:42:50 (7 years ago)
Author: pfleck
Message:

#2707 Simplified k-means clustering for ClusterCreators.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.VRPEnhancements/HeuristicLab.Problems.VehicleRouting/3.4/Encodings/Potvin/Creators/Cluster.cs

    r14447 r14559  
    1 using System;
    2 using System.Collections;
     1#region License Information
     2/* HeuristicLab
     3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
     5 * This file is part of HeuristicLab.
     6 *
     7 * HeuristicLab is free software: you can redistribute it and/or modify
     8 * it under the terms of the GNU General Public License as published by
     9 * the Free Software Foundation, either version 3 of the License, or
     10 * (at your option) any later version.
     11 *
     12 * HeuristicLab is distributed in the hope that it will be useful,
     13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
     14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     15 * GNU General Public License for more details.
     16 *
     17 * You should have received a copy of the GNU General Public License
     18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
     19 */
     20#endregion
     21
     22using System;
    323using System.Collections.Generic;
    424using HeuristicLab.Core;
    5 using HeuristicLab.Problems.VehicleRouting.Interfaces;
     25using HeuristicLab.Encodings.PermutationEncoding;
    626
    727namespace HeuristicLab.Problems.VehicleRouting.Encodings.Potvin {
    828
    9   public class ClusterAlgorithm<TCluster,TClusterElement>
    10     where TCluster : Cluster<TClusterElement>, new()
    11     where TClusterElement : ClusterElement, new() {
     29  public class KMeansAlgorithm<TData> where TData : class {
    1230
    13     public static List<TCluster> KMeans(IRandom random, List<TClusterElement> clusterElements,
    14       int k, double changeThreshold) {
    15       HashSet<int> initMeans = new HashSet<int>();
    16       int nextMean = -1;
    17       List<TCluster> clusters = CreateCList();
     31    private class ClusterInfo {
     32      public TData Mean;
     33      public double Variance;
     34      public int Size;
     35    }
     36
     37    private readonly Func<List<TData>, TData> meanCalculator;
     38    private readonly Func<TData, TData, double> distanceCalculator;
     39
     40    public KMeansAlgorithm(Func<List<TData>, TData> meanCalculator, Func<TData, TData, double> distanceCalculator) {
     41      this.meanCalculator = meanCalculator;
     42      this.distanceCalculator = distanceCalculator;
     43    }
     44
     45    public List<List<int>> Run(List<TData> data, int k, double changeThreshold, IRandom random) {
     46      int numClusters = Math.Min(k, data.Count);
     47
     48      var assignments = new int[data.Count];
     49      for (int i = 0; i < assignments.Length; i++)
     50        assignments[i] = -1; // unassigned
    1851
    1952      // (1) initialize each cluster with a random element as mean
    20       for (int i = 0; i < k && i < clusterElements.Count; i++) {
    21         TCluster cluster = new TCluster();
    22         cluster.Id = i;
    23 
    24         do {
    25           nextMean = random.Next(0, clusterElements.Count);
    26         } while (initMeans.Contains(nextMean));
    27         initMeans.Add(nextMean);
    28         cluster.SetMean(clusterElements[nextMean]);
    29         clusters.Add(cluster);
     53      var clusters = new ClusterInfo[numClusters];
     54      var initIndices = new Permutation(PermutationTypes.Absolute, data.Count, random);
     55      for (int c = 0; c < numClusters; c++) {
     56        assignments[initIndices[c]] = c;
     57        clusters[c] = new ClusterInfo {
     58          Mean = data[initIndices[c]],
     59          Size = 1
     60        };
    3061      }
    3162
    32       // (2) repeat clustering until change rate is below threshold
    33       double changeRate = 0.0;
     63      // (2) repeat clustering until change rate is below the threshold
     64      double changeRate;
    3465      do {
    35         int changes = KMeansRun(clusters, clusterElements);
    36         changeRate = (double)changes / clusterElements.Count;
     66        int changes = Iterate(data, assignments, clusters);
     67        changeRate = (double)changes / data.Count;
    3768      } while (changeRate > changeThreshold);
    3869
    39 
    40       // remove empty clusters
    41       clusters.RemoveAll(c => c.Elements.Count.Equals(0));
    42       return clusters;
     70      // (3) return non-empty clusters
     71      var clustersData = new List<List<int>>(numClusters);
     72      for (int c = 0; c < numClusters; c++)
     73        clustersData.Add(new List<int>());
     74      for (int i = 0; i < assignments.Length; i++)
     75        clustersData[assignments[i]].Add(i);
     76      clustersData.RemoveAll(c => c.Count == 0);
     77      return clustersData;
    4378    }
    4479
    45     private static int KMeansRun(List<TCluster> clusters, List<TClusterElement> clusterElements) {
     80    private int Iterate(List<TData> data, int[] assignments, ClusterInfo[] clusters) {
    4681      int changes = 0;
    4782
    48       // clear clusters from previous assigned elements
    49       foreach (var c in clusters) {
    50         c.Elements.Clear();
     83      var newAssignments = new int[data.Count];
     84      assignments.CopyTo(newAssignments, 0);
     85
     86      // assign elements to currently most suited cluster
     87      for (int i = 0; i < data.Count; i++) {
     88        int bestCluster = 0;
     89        double bestImpact = CalculateImpact(data[i], clusters[0]);
     90        for (int c = 1; c < clusters.Length; c++) {
     91          double impact = CalculateImpact(data[i], clusters[c]);
     92          if (impact < bestImpact) {
     93            bestImpact = impact;
     94            bestCluster = c;
     95          }
     96        }
     97        newAssignments[i] = bestCluster;
     98        if (newAssignments[i] != assignments[i])
     99          changes++;
    51100      }
    52101
    53       // assign elements to currently most suitable clusters
    54       foreach (var e in clusterElements) {
    55         int optClusterIdx = 0;
    56         double optImpact = clusters[optClusterIdx].CalculateImpact(e);
    57         for (int i = 1; i < clusters.Count; i++) {
    58           double impact = clusters[i].CalculateImpact(e);
    59           if (impact < optImpact) {
    60             optImpact = impact;
    61             optClusterIdx = i;
    62           }
    63         }
    64         if (clusters[optClusterIdx].AddElement(e)) {
    65           changes++;
    66         }
     102      // update clusters
     103      var clustersData = new List<List<TData>>(clusters.Length);
     104      for (int c = 0; c < clusters.Length; c++)
     105        clustersData.Add(new List<TData>());
     106      for (int i = 0; i < data.Count; i++) {
     107        assignments[i] = newAssignments[i];
     108        clustersData[assignments[i]].Add(data[i]);
    67109      }
     110      for (int c = 0; c < clusters.Length; c++) {
     111        var clusterData = clustersData[c];
     112        if (clusterData.Count == 0)
     113          continue;
    68114
    69       // update mean and variance
    70       foreach (var c in clusters) {
    71         c.CalculateMean();
    72         c.CalculateVariance();
     115        clusters[c].Mean = meanCalculator(clusterData);
     116
     117        clusters[c].Variance = 0;
     118        foreach (var e in clusterData)
     119          clusters[c].Variance += Math.Pow(distanceCalculator(e, clusters[c].Mean), 2);
     120        clusters[c].Variance /= clusterData.Count;
     121
     122        clusters[c].Size = clusterData.Count;
    73123      }
    74124
    75125      return changes;
    76126    }
    77127
    78     private static List<TCluster> CreateCList() {
    79       var listType = typeof(List<>);
    80       var constructedListType = listType.MakeGenericType(typeof(TCluster));
    81       return (List<TCluster>)Activator.CreateInstance(constructedListType);
     128    private double CalculateImpact(TData datum, ClusterInfo cluster) {
     129      double newVariance = (cluster.Variance * cluster.Size + Math.Pow(distanceCalculator(datum, cluster.Mean), 2)) / (cluster.Size + 1);
     130      return newVariance - cluster.Variance;
    82131    }
    83132  }
    84 
    85   #region Cluster
    86   public interface ICluster<T> where T : ClusterElement {
    87     void SetMean(T o);
    88     bool AddElement(T o);
    89     void CalculateMean();
    90     void CalculateVariance();
    91     double CalculateImpact(T e);
    92     double CalculateDistance(T e1, T e2);
    93   }
    94   public abstract class Cluster<T> : ICluster<T> where T : ClusterElement, new() {
    95     public int Id;
    96     public List<T> Elements { get; private set; }
    97     public T Mean { get; set; }
    98     public double Variance { get; set; }
    99     protected Cluster() {
    100       Elements = new List<T>();
    101     }
    102     protected Cluster(int id) {
    103       Id = id;
    104       Elements = new List<T>();
    105     }
    106     public bool AddElement(T e) {
    107       Elements.Add(e);
    108 
    109       bool clusterChanged = e.ClusterId != Id;
    110       e.ClusterId = Id;
    111       return clusterChanged;
    112     }
    113     public void SetMean(T e) {
    114       Mean = e;
    115       Mean.ClusterId = Id;
    116     }
    117     public abstract void CalculateMean();
    118     public abstract double CalculateDistance(T e1, T e2);
    119     public virtual void CalculateVariance() {
    120       if (Mean == null)
    121         CalculateMean();
    122 
    123       Variance = 0.0;
    124       foreach (T e in Elements) {
    125         Variance += Math.Pow(CalculateDistance(Mean, e), 2);
    126       }
    127       Variance /= Elements.Count;
    128     }
    129     public virtual double CalculateImpact(T e) {
    130       if (Mean == null)
    131         CalculateMean();
    132 
    133       double newVariance = (Variance * Elements.Count + Math.Pow(CalculateDistance(Mean, e), 2)) / (Elements.Count + 1);
    134       return newVariance - Variance;
    135     }
    136   }
    137   public class SpatialDistanceCluster : Cluster<SpatialDistanceClusterElement> {
    138     public SpatialDistanceCluster() : base() {
    139     }
    140 
    141     public SpatialDistanceCluster(int id) : base(id) {
    142     }
    143     public override void CalculateMean() {
    144       int dimensions = (Mean != null) ? Mean.Coordinates.Length : (Elements.Count > 0) ? Elements[0].Coordinates.Length : 0;
    145       SpatialDistanceClusterElement mean = new SpatialDistanceClusterElement(-1, dimensions, Id);
    146       foreach (SpatialDistanceClusterElement e in Elements) {
    147         for (int i = 0; i < Mean.Coordinates.Length; i++) {
    148           mean.Coordinates[i] += (e.Coordinates[i] / Elements.Count);
    149         }
    150       }
    151       Mean = mean;
    152     }
    153     public override double CalculateDistance(SpatialDistanceClusterElement e1, SpatialDistanceClusterElement e2) {
    154       if (!e1.Coordinates.Length.Equals(e2.Coordinates.Length)) {
    155         throw new ArgumentException("Distance could not be calculated since number of dimensions is unequal.");
    156       }
    157 
    158       double distance = 0.0;
    159       for (int i = 0; i < e1.Coordinates.Length; i++) {
    160         distance += Math.Pow(e1.Coordinates[i] - e2.Coordinates[i], 2);
    161       }
    162       return Math.Sqrt(distance);
    163     }
    164   }
    165   public class TemporalDistanceCluster : Cluster<TemporalDistanceClusterElement> {
    166     public TemporalDistanceCluster() : base() {
    167     }
    168     public TemporalDistanceCluster(int id) : base(id) {
    169     }
    170     public override void CalculateMean() {
    171       TemporalDistanceClusterElement mean = new TemporalDistanceClusterElement(-1, Id);
    172       foreach (TemporalDistanceClusterElement e in Elements) {
    173         mean.ReadyTime += e.ReadyTime/Elements.Count;
    174         mean.DueTime += e.DueTime/Elements.Count;
    175       }
    176       Mean = mean;
    177     }
    178     public override double CalculateDistance(TemporalDistanceClusterElement e1, TemporalDistanceClusterElement e2) {
    179       double distance = 0.0;
    180 
    181       distance += Math.Pow(e1.ReadyTime - e2.ReadyTime, 2);
    182       distance += Math.Pow(e1.DueTime - e2.DueTime, 2);
    183 
    184       return Math.Sqrt(distance);
    185     }
    186   }
    187   #endregion Cluster
    188 
    189   #region ClusterElement
    190   public abstract class ClusterElement {
    191     public int Id { get; set; }
    192     public int ClusterId { get; set; }
    193     protected ClusterElement() {
    194       ClusterId = -1;
    195     }
    196     public ClusterElement(int id, int clusterId = -1) {
    197       Id = id;
    198       ClusterId = clusterId;
    199     }
    200   }
    201   public class SpatialDistanceClusterElement : ClusterElement {
    202     public double[] Coordinates { get; set; }
    203     public SpatialDistanceClusterElement() : base() {
    204     }
    205     public SpatialDistanceClusterElement(int id, int clusterId = -1) : base(id, clusterId) {
    206     }
    207     public SpatialDistanceClusterElement(int id, double[] coordinates, int clusterId = -1) : base(id, clusterId) {
    208       Coordinates = coordinates;
    209     }
    210     public SpatialDistanceClusterElement(int id, int dimensions, int clusterId = -1) : base(id, clusterId) {
    211       Coordinates = new double[dimensions];
    212     }
    213   }
    214   public class TemporalDistanceClusterElement : ClusterElement {
    215     public double ReadyTime { get; set; }
    216     public double DueTime { get; set; }
    217     public TemporalDistanceClusterElement() : base() {
    218     }
    219     public TemporalDistanceClusterElement(int id, int clusterId = -1) : base(id, clusterId) {
    220     }
    221     public TemporalDistanceClusterElement(int id, double readyTime, double dueTime, int clusterId = -1) : base(id, clusterId) {
    222       ReadyTime = readyTime;
    223       DueTime = dueTime;
    224     }
    225   }
    226 
    227   #endregion ClusterElement
    228133}
Note: See TracChangeset for help on using the changeset viewer.