Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/12/11 18:53:51 (13 years ago)
Author:
gkronber
Message:

#1581: implemented caching for SVM models.

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/CrossValidation.cs

    r6557 r6566  
    363363
    364364    public void CollectResultValues(IDictionary<string, IItem> results) {
     365      var clonedResults = (ResultCollection)this.results.Clone();
     366      foreach (var result in clonedResults) {
     367        results.Add(result.Name, result.Value);
     368      }
     369    }
     370
     371    private void AggregateResultValues(IDictionary<string, IItem> results) {
    365372      Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();
    366373      IEnumerable<IRun> runs = clonedAlgorithms.Select(alg => alg.Runs.FirstOrDefault()).Where(run => run != null);
     
    397404      List<IResult> aggregatedResults = new List<IResult>();
    398405      foreach (KeyValuePair<string, List<IRegressionSolution>> solutions in resultSolutions) {
    399         var problemDataClone = (IRegressionProblemData)Problem.ProblemData.Clone();
     406        // clone manually to correctly clone references between cloned root objects
     407        Cloner cloner = new Cloner();
     408        var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData);
     409        // set partitions of problem data clone correctly
    400410        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
    401411        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
    402         var ensembleSolution = new RegressionEnsembleSolution(solutions.Value.Select(x => x.Model), problemDataClone,
    403           solutions.Value.Select(x => x.ProblemData.TrainingPartition),
    404           solutions.Value.Select(x => x.ProblemData.TestPartition));
     412        // clone models
     413        var ensembleSolution = new RegressionEnsembleSolution(
     414          solutions.Value.Select(x => cloner.Clone(x.Model)),
     415          problemDataClone,
     416          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),
     417          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));
    405418
    406419        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
     
    425438      var aggregatedResults = new List<IResult>();
    426439      foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) {
    427         var problemDataClone = (IClassificationProblemData)Problem.ProblemData.Clone();
     440        // clone manually to correctly clone references between cloned root objects
     441        Cloner cloner = new Cloner();
     442        var problemDataClone = (IClassificationProblemData)cloner.Clone(Problem.ProblemData);
     443        // set partitions of problem data clone correctly
    428444        problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
    429445        problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
    430         var ensembleSolution = new ClassificationEnsembleSolution(solutions.Value.Select(x => x.Model), problemDataClone,
    431           solutions.Value.Select(x => x.ProblemData.TrainingPartition),
    432           solutions.Value.Select(x => x.ProblemData.TestPartition));
     446        // clone models
     447        var ensembleSolution = new ClassificationEnsembleSolution(
     448          solutions.Value.Select(x => cloner.Clone(x.Model)),
     449          problemDataClone,
     450          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),
     451          solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));
    433452
    434453        aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
     
    702721      stopPending = false;
    703722      Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>();
    704       CollectResultValues(collectedResults);
     723      AggregateResultValues(collectedResults);
    705724      results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray());
    706725      runsCounter++;
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineModel.cs

    r5861 r6566  
    9898      this.targetVariable = original.targetVariable;
    9999      this.allowedInputVariables = (string[])original.allowedInputVariables.Clone();
     100      foreach (var dataset in original.cachedPredictions.Keys) {
     101        this.cachedPredictions.Add(cloner.Clone(dataset), (double[])original.cachedPredictions[dataset].Clone());
     102      }
    100103      if (original.classValues != null)
    101104        this.classValues = (double[])original.classValues.Clone();
     
    145148    }
    146149    #endregion
     150    // cache for predictions, which is cloned but not persisted, must be cleared when the model is changed
     151    private Dictionary<Dataset, double[]> cachedPredictions = new Dictionary<Dataset, double[]>();
    147152    private IEnumerable<double> GetEstimatedValuesHelper(Dataset dataset, IEnumerable<int> rows) {
     153      if (!cachedPredictions.ContainsKey(dataset)) {
     154        // create an array of cached predictions which is initially filled with NaNs
     155        double[] predictions = Enumerable.Repeat(double.NaN, dataset.Rows).ToArray();
     156        CalculatePredictions(dataset, rows, predictions);
     157        cachedPredictions.Add(dataset, predictions);
     158      }
     159      // get the array of predictions and select the subset of requested rows
     160      double[] p = cachedPredictions[dataset];
     161      var requestedPredictions = from r in rows
     162                                 select p[r];
     163      // check if the requested predictions contain NaNs
     164      // (this means for the request rows some predictions have not been cached)
     165      if (requestedPredictions.Any(x => double.IsNaN(x))) {
     166        // updated the predictions for currently requested rows
     167        CalculatePredictions(dataset, rows, p);
     168        cachedPredictions[dataset] = p;
     169        // now we can be sure that for the current rows all predictions are available
     170        return from r in rows
     171               select p[r];
     172      } else {
     173        // there were no NaNs => just return the cached predictions
     174        return requestedPredictions;
     175      }
     176    }
     177
     178    private void CalculatePredictions(Dataset dataset, IEnumerable<int> rows, double[] predictions) {
     179      // calculate and cache predictions for the currently requested rows
    148180      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
    149181      SVM.Problem scaledProblem = Scaling.Scale(RangeTransform, problem);
    150182
    151       foreach (var row in Enumerable.Range(0, scaledProblem.Count)) {
    152         yield return SVM.Prediction.Predict(Model, scaledProblem.X[row]);
    153       }
    154     }
     183      // row is the index in the original dataset,
     184      // i is the index in the scaled dataset (containing only the necessary rows)
     185      int i = 0;
     186      foreach (var row in rows) {
     187        predictions[row] = SVM.Prediction.Predict(Model, scaledProblem.X[i]);
     188        i++;
     189      }
     190    }
     191
    155192    #region events
    156193    public event EventHandler Changed;
    157194    private void OnChanged(EventArgs e) {
     195      cachedPredictions.Clear();
    158196      var handlers = Changed;
    159197      if (handlers != null)
Note: See TracChangeset for help on using the changeset viewer.