Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/11/11 15:03:46 (14 years ago)
Author:
gkronber
Message:

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineCrossValidationEvaluator.cs

    r4068 r5275  
    2121
    2222using System;
     23using System.Collections.Generic;
    2324using System.Linq;
     25using HeuristicLab.Common;
    2426using HeuristicLab.Core;
    2527using HeuristicLab.Data;
     
    125127    #endregion
    126128
     129    [StorableConstructor]
     130    protected SupportVectorMachineCrossValidationEvaluator(bool deserializing) : base(deserializing) { }
     131
     132    protected SupportVectorMachineCrossValidationEvaluator(SupportVectorMachineCrossValidationEvaluator original,
     133      Cloner cloner)
     134      : base(original, cloner) { }
    127135    public SupportVectorMachineCrossValidationEvaluator()
    128136      : base() {
     
    142150    }
    143151
     152    public override IDeepCloneable Clone(Cloner cloner) {
     153      return new SupportVectorMachineCrossValidationEvaluator(this, cloner);
     154    }
     155
    144156    public override IOperation Apply() {
    145       double reductionRatio = 1.0;
     157      double reductionRatio = 1.0; // TODO: make parameter
    146158      if (ActualSamplesParameter.ActualValue != null)
    147159        reductionRatio = ActualSamplesParameter.ActualValue.Value;
    148 
    149       int reducedRows = (int)((SamplesEnd.Value - SamplesStart.Value) * reductionRatio);
     160      IEnumerable<int> rows =
     161        Enumerable.Range(SamplesStart.Value, SamplesEnd.Value - SamplesStart.Value)
     162        .Where(i => i < DataAnalysisProblemData.TestSamplesStart.Value || DataAnalysisProblemData.TestSamplesEnd.Value <= i);
     163
     164      // create a new DataAnalysisProblemData instance
    150165      DataAnalysisProblemData reducedProblemData = (DataAnalysisProblemData)DataAnalysisProblemData.Clone();
    151       reducedProblemData.Dataset = CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, reductionRatio, SamplesStart.Value, SamplesEnd.Value);
     166      reducedProblemData.Dataset =
     167        CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, rows, reductionRatio);
     168      reducedProblemData.TrainingSamplesStart.Value = 0;
     169      reducedProblemData.TrainingSamplesEnd.Value = reducedProblemData.Dataset.Rows;
     170      reducedProblemData.TestSamplesStart.Value = reducedProblemData.Dataset.Rows;
     171      reducedProblemData.TestSamplesEnd.Value = reducedProblemData.Dataset.Rows;
     172      reducedProblemData.ValidationPercentage.Value = 0;
    152173
    153174      double quality = PerformCrossValidation(reducedProblemData,
    154                              SamplesStart.Value, SamplesStart.Value + reducedRows,
    155175                             SvmType.Value, KernelType.Value,
    156176                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, NumberOfFolds.Value);
     
    160180    }
    161181
    162     private Dataset CreateReducedDataset(IRandom random, Dataset dataset, double reductionRatio, int start, int end) {
    163       int n = (int)((end - start) * reductionRatio);
     182    private Dataset CreateReducedDataset(IRandom random, Dataset dataset, IEnumerable<int> rowIndices, double reductionRatio) {
     183
    164184      // must not make a fink:
    165185      // => select n rows randomly from start..end
     
    168188
    169189      // all possible rowIndexes from start..end
    170       int[] rowIndexes = Enumerable.Range(start, end - start).ToArray();
     190      int[] rowIndexArr = rowIndices.ToArray();
     191      int n = (int)Math.Max(1.0, rowIndexArr.Length * reductionRatio);
    171192
    172193      // knuth shuffle
    173       for (int i = rowIndexes.Length - 1; i > 0; i--) {
     194      for (int i = rowIndexArr.Length - 1; i > 0; i--) {
    174195        int j = random.Next(0, i);
    175196        // swap
    176         int tmp = rowIndexes[i];
    177         rowIndexes[i] = rowIndexes[j];
    178         rowIndexes[j] = tmp;
     197        int tmp = rowIndexArr[i];
     198        rowIndexArr[i] = rowIndexArr[j];
     199        rowIndexArr[j] = tmp;
    179200      }
    180201
    181202      // take the first n indexes (selected n rowIndexes from start..end)
    182203      // now order by index
    183       var orderedRandomIndexes = rowIndexes.Take(n).OrderBy(x => x).ToArray();
    184 
    185       // now build a dataset collecting the rows from orderedRandomIndexes into the dataset starting at index start
    186       double[,] reducedData = dataset.GetClonedData();
     204      int[] orderedRandomIndexes =
     205        rowIndexArr.Take(n)
     206        .OrderBy(x => x)
     207        .ToArray();
     208
     209      // now build a dataset containing only rows from orderedRandomIndexes
     210      double[,] reducedData = new double[n, dataset.Columns];
    187211      for (int i = 0; i < n; i++) {
    188212        for (int column = 0; column < dataset.Columns; column++) {
    189           reducedData[start + i, column] = dataset[orderedRandomIndexes[i], column];
     213          reducedData[i, column] = dataset[orderedRandomIndexes[i], column];
    190214        }
    191215      }
     
    198222      double cost, double nu, double gamma, double epsilon,
    199223      int nFolds) {
    200       return PerformCrossValidation(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
     224      return PerformCrossValidation(problemData, problemData.TrainingIndizes, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
    201225    }
    202226
    203227    public static double PerformCrossValidation(
    204228      DataAnalysisProblemData problemData,
    205       int start, int end,
     229      IEnumerable<int> rowIndices,
    206230      string svmType, string kernelType,
    207231      double cost, double nu, double gamma, double epsilon,
     
    221245
    222246
    223       SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
     247      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, rowIndices);
    224248      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
    225249      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);
Note: See TracChangeset for help on using the changeset viewer.