Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/01/10 12:31:04 (14 years ago)
Author:
gkronber
Message:

Adapted SVM classes to work correctly for overlapping training / test partitions. #1226

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3
Files:
6 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/DataAnalysisProblemData.cs

    r4473 r4543  
    180180      get {
    181181        return Enumerable.Range(TrainingSamplesStart.Value, TrainingSamplesEnd.Value - TrainingSamplesStart.Value)
    182                          .Where(i => i > 0 && i < Dataset.Rows && (i < TestSamplesStart.Value || TestSamplesEnd.Value <= i));
     182                         .Where(i => i >= 0 && i < Dataset.Rows && (i < TestSamplesStart.Value || TestSamplesEnd.Value <= i));
    183183      }
    184184    }
     
    186186      get {
    187187        return Enumerable.Range(TestSamplesStart.Value, TestSamplesEnd.Value - TestSamplesStart.Value)
    188            .Where(i => i > 0 && i < Dataset.Rows);
     188           .Where(i => i >= 0 && i < Dataset.Rows);
    189189      }
    190190    }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineCrossValidationEvaluator.cs

    r4068 r4543  
    2929using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3030using SVM;
     31using System.Collections.Generic;
    3132
    3233namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     
    143144
    144145    public override IOperation Apply() {
    145       double reductionRatio = 1.0;
     146      double reductionRatio = 1.0; // TODO: make parameter
    146147      if (ActualSamplesParameter.ActualValue != null)
    147148        reductionRatio = ActualSamplesParameter.ActualValue.Value;
    148 
    149       int reducedRows = (int)((SamplesEnd.Value - SamplesStart.Value) * reductionRatio);
     149      IEnumerable<int> rows =
     150        Enumerable.Range(SamplesStart.Value, SamplesEnd.Value - SamplesStart.Value)
     151        .Where(i => i < DataAnalysisProblemData.TestSamplesStart.Value || DataAnalysisProblemData.TestSamplesEnd.Value <= i);
     152
     153      // create a new DataAnalysisProblemData instance
    150154      DataAnalysisProblemData reducedProblemData = (DataAnalysisProblemData)DataAnalysisProblemData.Clone();
    151       reducedProblemData.Dataset = CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, reductionRatio, SamplesStart.Value, SamplesEnd.Value);
     155      reducedProblemData.Dataset =
     156        CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, rows, reductionRatio);
     157      reducedProblemData.TrainingSamplesStart.Value = 0;
     158      reducedProblemData.TrainingSamplesEnd.Value = reducedProblemData.Dataset.Rows;
     159      reducedProblemData.TestSamplesStart.Value = reducedProblemData.Dataset.Rows;
     160      reducedProblemData.TestSamplesEnd.Value = reducedProblemData.Dataset.Rows;
     161      reducedProblemData.ValidationPercentage.Value = 0;
    152162
    153163      double quality = PerformCrossValidation(reducedProblemData,
    154                              SamplesStart.Value, SamplesStart.Value + reducedRows,
    155164                             SvmType.Value, KernelType.Value,
    156165                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, NumberOfFolds.Value);
     
    160169    }
    161170
    162     private Dataset CreateReducedDataset(IRandom random, Dataset dataset, double reductionRatio, int start, int end) {
    163       int n = (int)((end - start) * reductionRatio);
     171    private Dataset CreateReducedDataset(IRandom random, Dataset dataset, IEnumerable<int> rowIndices, double reductionRatio) {
     172     
    164173      // must not make a fink:
    165174      // => select n rows randomly from start..end
     
    168177
    169178      // all possible rowIndexes from start..end
    170       int[] rowIndexes = Enumerable.Range(start, end - start).ToArray();
     179      int[] rowIndexArr = rowIndices.ToArray();
     180      int n = (int)Math.Max(1.0, rowIndexArr.Length * reductionRatio);
    171181
    172182      // knuth shuffle
    173       for (int i = rowIndexes.Length - 1; i > 0; i--) {
     183      for (int i = rowIndexArr.Length - 1; i > 0; i--) {
    174184        int j = random.Next(0, i);
    175185        // swap
    176         int tmp = rowIndexes[i];
    177         rowIndexes[i] = rowIndexes[j];
    178         rowIndexes[j] = tmp;
     186        int tmp = rowIndexArr[i];
     187        rowIndexArr[i] = rowIndexArr[j];
     188        rowIndexArr[j] = tmp;
    179189      }
    180190
    181191      // take the first n indexes (selected n rowIndexes from start..end)
    182192      // now order by index
    183       var orderedRandomIndexes = rowIndexes.Take(n).OrderBy(x => x).ToArray();
    184 
    185       // now build a dataset collecting the rows from orderedRandomIndexes into the dataset starting at index start
    186       double[,] reducedData = dataset.GetClonedData();
     193      int[] orderedRandomIndexes =
     194        rowIndexArr.Take(n)
     195        .OrderBy(x => x)
     196        .ToArray();
     197
     198      // now build a dataset containing only rows from orderedRandomIndexes
     199      double[,] reducedData = new double[n, dataset.Columns];
    187200      for (int i = 0; i < n; i++) {
    188201        for (int column = 0; column < dataset.Columns; column++) {
    189           reducedData[start + i, column] = dataset[orderedRandomIndexes[i], column];
     202          reducedData[i, column] = dataset[orderedRandomIndexes[i], column];
    190203        }
    191204      }
     
    198211      double cost, double nu, double gamma, double epsilon,
    199212      int nFolds) {
    200       return PerformCrossValidation(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
     213      return PerformCrossValidation(problemData, problemData.TrainingIndizes, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
    201214    }
    202215
    203216    public static double PerformCrossValidation(
    204217      DataAnalysisProblemData problemData,
    205       int start, int end,
     218      IEnumerable<int> rowIndices,
    206219      string svmType, string kernelType,
    207220      double cost, double nu, double gamma, double epsilon,
     
    221234
    222235
    223       SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
     236      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, rowIndices);
    224237      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
    225238      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModel.cs

    r4068 r4543  
    7272
    7373    public IEnumerable<double> GetEstimatedValues(DataAnalysisProblemData problemData, int start, int end) {
    74       SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
     74      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, Enumerable.Range(start, end - start));
    7575      SVM.Problem scaledProblem = Scaling.Scale(RangeTransform, problem);
    7676
    7777      return (from row in Enumerable.Range(0, scaledProblem.Count)
    78               select SVM.Prediction.Predict(Model, scaledProblem.X[row])).ToList();
     78              select SVM.Prediction.Predict(Model, scaledProblem.X[row]))
     79              .ToList();
    7980    }
    8081
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelCreator.cs

    r4068 r4543  
    2727using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2828using SVM;
     29using System.Collections.Generic;
     30using System.Linq;
    2931
    3032namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     
    125127
    126128    public override IOperation Apply() {
     129      int start = SamplesStart.Value;
     130      int end = SamplesEnd.Value;
     131      IEnumerable<int> rows =
     132        Enumerable.Range(start, end-start)
     133        .Where(i => i < DataAnalysisProblemData.TestSamplesStart.Value || DataAnalysisProblemData.TestSamplesEnd.Value <= i);
     134
    127135      SupportVectorMachineModel model = TrainModel(DataAnalysisProblemData,
    128                              SamplesStart.Value, SamplesEnd.Value,
     136                             rows,
    129137                             SvmType.Value, KernelType.Value,
    130138                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value);
     
    138146      string svmType, string kernelType,
    139147      double cost, double nu, double gamma, double epsilon) {
    140       return TrainModel(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value, svmType, kernelType, cost, nu, gamma, epsilon);
     148      return TrainModel(problemData, problemData.TrainingIndizes, svmType, kernelType, cost, nu, gamma, epsilon);
    141149    }
    142150
    143151    public static SupportVectorMachineModel TrainModel(
    144152      DataAnalysisProblemData problemData,
    145       int start, int end,
     153      IEnumerable<int> trainingIndizes,
    146154      string svmType, string kernelType,
    147155      double cost, double nu, double gamma, double epsilon) {
     
    160168
    161169
    162       SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
     170      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, trainingIndizes);
    163171      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
    164172      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelEvaluator.cs

    r4068 r4543  
    2626using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2727using SVM;
     28using System.Collections.Generic;
     29using System.Linq;
    2830
    2931namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     
    8082      int start = SamplesStart.Value;
    8183      int end = SamplesEnd.Value;
     84      IEnumerable<int> rows =
     85        Enumerable.Range(start, end - start)
     86        .Where(i => i < DataAnalysisProblemData.TestSamplesStart.Value || DataAnalysisProblemData.TestSamplesEnd.Value <= i);
    8287
    83       ValuesParameter.ActualValue = new DoubleMatrix(Evaluate(SupportVectorMachineModel, DataAnalysisProblemData, start, end));
     88      ValuesParameter.ActualValue = new DoubleMatrix(Evaluate(SupportVectorMachineModel, DataAnalysisProblemData, rows));
    8489      return base.Apply();
    8590    }
    8691
    87     public static double[,] Evaluate(SupportVectorMachineModel model, DataAnalysisProblemData problemData, int start, int end) {
    88       SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
     92    public static double[,] Evaluate(SupportVectorMachineModel model, DataAnalysisProblemData problemData, IEnumerable<int> rowIndices) {
     93      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, rowIndices);
    8994      SVM.Problem scaledProblem = model.RangeTransform.Scale(problem);
    9095
     
    9297
    9398      double[,] values = new double[scaledProblem.Count, 2];
     99      var rowEnumerator = rowIndices.GetEnumerator();
    94100      for (int i = 0; i < scaledProblem.Count; i++) {
    95         values[i, 0] = problemData.Dataset[start + i, targetVariableIndex];
     101        rowEnumerator.MoveNext();
     102        values[i, 0] = problemData.Dataset[rowEnumerator.Current, targetVariableIndex];
    96103        values[i, 1] = SVM.Prediction.Predict(model.Model, scaledProblem.X[i]);
    97104      }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineUtil.cs

    r4068 r4543  
    2929    /// </summary>
    3030    /// <param name="problemData">The problem data to transform</param>
    31     /// <param name="start">The index of the first row of <paramref name="problemData"/> to copy to the output.</param>
    32     /// <param name="end">The last of the first row of <paramref name="problemData"/> to copy to the output.</param>
     31    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem</param>
    3332    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    34     public static SVM.Problem CreateSvmProblem(DataAnalysisProblemData problemData, int start, int end) {
    35       int rowCount = end - start;
    36       var targetVector = problemData.Dataset.GetVariableValues(problemData.TargetVariable.Value, start, end);
     33    public static SVM.Problem CreateSvmProblem(DataAnalysisProblemData problemData, IEnumerable<int> rowIndices) {
     34      double[] targetVector =
     35        problemData.Dataset.GetEnumeratedVariableValues(problemData.TargetVariable.Value, rowIndices)
     36        .ToArray();
    3737
    3838      SVM.Node[][] nodes = new SVM.Node[targetVector.Length][];
    3939      List<SVM.Node> tempRow;
    4040      int maxNodeIndex = 0;
    41       for (int row = 0; row < rowCount; row++) {
     41      int svmProblemRowIndex = 0;
     42      foreach (int row in rowIndices) {
    4243        tempRow = new List<SVM.Node>();
    4344        foreach (var inputVariable in problemData.InputVariables.CheckedItems) {
    4445          int col = problemData.Dataset.GetVariableIndex(inputVariable.Value.Value);
    45           double value = problemData.Dataset[start + row, col];
     46          double value = problemData.Dataset[row, col];
    4647          if (!double.IsNaN(value)) {
    47             int nodeIndex = col + 1; // make sure the smallest nodeIndex = 1
     48            int nodeIndex = col + 1; // make sure the smallest nodeIndex is 1 (libSVM convention)
    4849            tempRow.Add(new SVM.Node(nodeIndex, value));
    4950            if (nodeIndex > maxNodeIndex) maxNodeIndex = nodeIndex;
    5051          }
    5152        }
    52         nodes[row] = tempRow.OrderBy(x => x.Index).ToArray(); // make sure the values are sorted by node index
     53        nodes[svmProblemRowIndex++] = tempRow.OrderBy(x => x.Index).ToArray(); // make sure the values are sorted by node index
    5354      }
    5455
Note: See TracChangeset for help on using the changeset viewer.