- Timestamp:
- 12/22/16 10:08:25 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAnalysis.cs
r14512 r14518 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Drawing; 24 25 using System.Linq; 26 using System.Threading; 25 27 using HeuristicLab.Analysis; 26 28 using HeuristicLab.Common; … … 28 30 using HeuristicLab.Data; 29 31 using HeuristicLab.Encodings.RealVectorEncoding; 32 using HeuristicLab.Optimization; 30 33 using HeuristicLab.Parameters; 31 34 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; … … 40 43 [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)] 41 44 [StorableClass] 42 public sealed class TSNEAnalysis : FixedDataAnalysisAlgorithm<IRegressionProblem> { 43 45 public sealed class TSNEAnalysis : BasicAlgorithm { 46 47 public override Type ProblemType 48 { 49 get { return typeof(IDataAnalysisProblem); } 50 } 51 public new IDataAnalysisProblem Problem 52 { 53 get { return (IDataAnalysisProblem)base.Problem; } 54 set { base.Problem = value; } 55 } 44 56 #region Resultnames 45 57 private const string ScatterPlotResultName = "Scatterplot"; … … 61 73 private const string SeedParameterName = "Seed"; 62 74 private const string ClassesParameterName = "ClassNames"; 75 private const string NormalizationParameterName = "Normalization"; 63 76 #endregion 64 77 … … 115 128 { 116 129 get { return Parameters[ClassesParameterName] as IFixedValueParameter<StringValue>; } 130 } 131 public IFixedValueParameter<BoolValue> NormalizationParameter 132 { 133 get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; } 117 134 } 118 135 #endregion … … 174 191 get { return ClassesParameter.Value.Value; } 175 192 } 176 193 public bool Normalization 194 { 195 get { return NormalizationParameter.Value.Value; } 196 } 177 197 [Storable] 178 198 public TSNE<RealVector> tsne; … … 191 211 Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis", new IntValue(2))); 192 212 Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent", new IntValue(1000))); 193 Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated", new IntValue( 250)));194 Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched", new IntValue( 250)));213 Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated", new IntValue(0))); 214 Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched", new IntValue(0))); 195 215 Parameters.Add(new FixedValueParameter<DoubleValue>(InitialMomentumParameterName, "The initial momentum in the gradient descent", new DoubleValue(0.5))); 196 216 Parameters.Add(new FixedValueParameter<DoubleValue>(FinalMomentumParameterName, "The final momentum", new DoubleValue(0.8))); … … 199 219 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random", new IntValue(0))); 200 220 Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. \n if the lable column can not be found Training/Test is used as labels", new StringValue("none"))); 201 } 202 #endregion 203 204 protected override void Run() { 205 var data = CalculateProjectedData(Problem.ProblemData); 206 var lowDimData = new DoubleMatrix(data); 207 } 221 Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Wether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored", new BoolValue(true))); 222 223 MomentumSwitchIterationParameter.Hidden = true; 224 InitialMomentumParameter.Hidden = true; 225 FinalMomentumParameter.Hidden = true; 226 StopLyingIterationParameter.Hidden = true; 227 EtaParameter.Hidden = true; 228 } 229 #endregion 208 230 209 231 public override void Stop() { … … 212 234 } 213 235 236 protected override void Run(CancellationToken cancellationToken) { 237 var data = CalculateProjectedData(Problem.ProblemData); 238 var lowDimData = new DoubleMatrix(data); 239 } 240 214 241 private double[,] CalculateProjectedData(IDataAnalysisProblemData problemData) { 215 var DataRowNames = new Dictionary<string, List<int>>();242 var dataRowNames = new Dictionary<string, List<int>>(); 216 243 var rows = new Dictionary<string, ScatterPlotDataRow>(); 217 244 … … 220 247 var classes = problemData.Dataset.GetStringValues(Classes).ToArray(); 221 248 for (int i = 0; i < classes.Length; i++) { 222 if (! DataRowNames.ContainsKey(classes[i])) DataRowNames.Add(classes[i], new List<int>());223 DataRowNames[classes[i]].Add(i); //always succeeds249 if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>()); 250 dataRowNames[classes[i]].Add(i); //always succeeds 224 251 } 225 252 } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) { … … 228 255 var min = classValues.Min() - 0.1; 229 256 var contours = 8; 230 for ( inti = 0; i < contours; i++) {257 for (var i = 0; i < contours; i++) { 231 258 var name = GetContourName(i, min, max, contours); 232 DataRowNames.Add(name, new List<int>());259 dataRowNames.Add(name, new List<int>()); 233 260 rows.Add(name, new ScatterPlotDataRow(name, "", new List<Point2D<double>>())); 234 261 rows[name].VisualProperties.Color = GetHeatMapColor(i, contours); 235 rows[name].VisualProperties.PointSize = i +3;262 rows[name].VisualProperties.PointSize = i + 3; 236 263 } 237 264 for (int i = 0; i < classValues.Length; i++) { 238 DataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); //always succeeds265 dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); //always succeeds 239 266 } 240 267 241 268 } 242 269 243 244 270 } else { 245 DataRowNames.Add("Training", problemData.TrainingIndices.ToList());246 DataRowNames.Add("Test", problemData.TestIndices.ToList());271 dataRowNames.Add("Training", problemData.TrainingIndices.ToList()); 272 dataRowNames.Add("Test", problemData.TestIndices.ToList()); 247 273 } 248 274 249 275 var random = SetSeedRandomly ? new MersenneTwister() : new MersenneTwister(Seed); 250 tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, DataRowNames, rows);276 tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, dataRowNames, rows); 251 277 var dataset = problemData.Dataset; 252 278 var allowedInputVariables = problemData.AllowedInputVariables.ToArray(); 253 279 var data = new RealVector[dataset.Rows]; 254 280 for (var row = 0; row < dataset.Rows; row++) data[row] = new RealVector(allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray()); 281 282 if (Normalization) { 283 data = NormalizeData(data); 284 } 285 255 286 return tsne.Run(data, NewDimensions, Perplexity, Theta); 287 } 288 289 private RealVector[] NormalizeData(RealVector[] data) { 290 var n = data[0].Length; 291 var mean = new double[n]; 292 var sd = new double[n]; 293 var nData = new RealVector[data.Length]; 294 for (var i = 0; i < n; i++) { 295 var i1 = i; 296 sd[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).StandardDeviation(); 297 mean[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).Average(); 298 } 299 for (int i = 0; i < data.Length; i++) { 300 nData[i] = new RealVector(n); 301 for (int j = 0; j < n; j++) { 302 nData[i][j] = (data[i][j] - mean[j]) / sd[j]; 303 } 304 } 305 return nData; 306 307 256 308 } 257 309
Note: See TracChangeset
for help on using the changeset viewer.