Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/02/16 09:02:09 (8 years ago)
Author:
gkronber
Message:

#2590 merged r13697:13698, r13700:13702, r13704:13705, r13711, r13715 from trunk to stable

Location:
stable
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Problems.DataAnalysis

  • stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionEnsembleSolution.cs

    r13049 r13976  
    79 79        }
    80 80      }
     81
     82      RegisterModelEvents();
    81 83      RegisterRegressionSolutionsEventHandler();
    82 84    }
     
    93 95      }
    94 96
     97      evaluationCache = new Dictionary<int, double>(original.ProblemData.Dataset.Rows);
    95 98      trainingEvaluationCache = new Dictionary<int, double>(original.ProblemData.TrainingIndices.Count());
    96 99      testEvaluationCache = new Dictionary<int, double>(original.ProblemData.TestIndices.Count());
    97 100
    98 101      regressionSolutions = cloner.Clone(original.regressionSolutions);
     102      RegisterModelEvents();
    99 103      RegisterRegressionSolutionsEventHandler();
    100 104    }
     
    106 110      regressionSolutions = new ItemCollection<IRegressionSolution>();
    107 111
     112      RegisterModelEvents();
    108 113      RegisterRegressionSolutionsEventHandler();
    109 114    }
    110 115
    111 116    public RegressionEnsembleSolution(IRegressionProblemData problemData)
    112       : this(Enumerable.Empty<IRegressionModel>(), problemData) {
    113     }
    114 
    115     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData)
    116       : this(models, problemData,
    117              models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
    118              models.Select(m => (IntRange)problemData.TestPartition.Clone())
    119       ) { }
    120 
    121     public RegressionEnsembleSolution(IEnumerable<IRegressionModel> models, IRegressionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
    122       : base(new RegressionEnsembleModel(Enumerable.Empty<IRegressionModel>()), new RegressionEnsembleProblemData(problemData)) {
    123       this.trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
    124       this.testPartitions = new Dictionary<IRegressionModel, IntRange>();
    125       this.regressionSolutions = new ItemCollection<IRegressionSolution>();
    126 
    127       List<IRegressionSolution> solutions = new List<IRegressionSolution>();
    128       var modelEnumerator = models.GetEnumerator();
    129       var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
    130       var testPartitionEnumerator = testPartitions.GetEnumerator();
    131 
    132       while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
    133         var p = (IRegressionProblemData)problemData.Clone();
    134         p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
    135         p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
    136         p.TestPartition.Start = testPartitionEnumerator.Current.Start;
    137         p.TestPartition.End = testPartitionEnumerator.Current.End;
    138 
    139         solutions.Add(modelEnumerator.Current.CreateRegressionSolution(p));
    140       }
    141       if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
    142         throw new ArgumentException();
    143       }
    144 
     117      : this(new RegressionEnsembleModel(), problemData) {
     118    }
     119
     120    public RegressionEnsembleSolution(IRegressionEnsembleModel model, IRegressionProblemData problemData)
     121      : base(model, new RegressionEnsembleProblemData(problemData)) {
     122      trainingPartitions = new Dictionary<IRegressionModel, IntRange>();
     123      testPartitions = new Dictionary<IRegressionModel, IntRange>();
     124      regressionSolutions = new ItemCollection<IRegressionSolution>();
     125
     126      evaluationCache = new Dictionary<int, double>(problemData.Dataset.Rows);
    145 127      trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count());
    146 128      testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count());
    147 129
     130
     131      var solutions = model.Models.Select(m => m.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()));
     132      foreach (var solution in solutions) {
     133        regressionSolutions.Add(solution);
     134        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     135        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     136      }
     137
     138      RecalculateResults();
     139      RegisterModelEvents();
    148 140      RegisterRegressionSolutionsEventHandler();
    149       regressionSolutions.AddRange(solutions);
    150     }
     141    }
     142
    151 143
    152 144    public override IDeepCloneable Clone(Cloner cloner) {
    153 145      return new RegressionEnsembleSolution(this, cloner);
     146    }
     147
     148    private void RegisterModelEvents() {
     149      Model.Changed += Model_Changed;
    154 150    }
    155 151    private void RegisterRegressionSolutionsEventHandler() {
     
    168 164        var rows = ProblemData.TrainingIndices;
    169 165        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
     166
    170 167        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    171         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
     168        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
    172 169
    173 170        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    184 181        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
    185 182        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    186         var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
     183        var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate, RowIsTestForModel).GetEnumerator();
    187 184
    188 185        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    193 190      }
    194 191    }
    195 
    196     private IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows, Func<int, IRegressionModel, bool> modelSelectionPredicate) {
    197       var estimatedValuesEnumerators = (from model in Model.Models
    198                                         select new { Model = model, EstimatedValuesEnumerator = model.GetEstimatedValues(ProblemData.Dataset, rows).GetEnumerator() })
    199                                        .ToList();
    200       var rowsEnumerator = rows.GetEnumerator();
    201       // aggregate to make sure that MoveNext is called for all enumerators
    202       while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
    203         int currentRow = rowsEnumerator.Current;
    204 
    205         var selectedEnumerators = from pair in estimatedValuesEnumerators
    206                                   where modelSelectionPredicate(currentRow, pair.Model)
    207                                   select pair.EstimatedValuesEnumerator;
    208 
    209         yield return AggregateEstimatedValues(selectedEnumerators.Select(x => x.Current));
    210       }
    211     }
    212 
    213 192    private bool RowIsTrainingForModel(int currentRow, IRegressionModel model) {
    214 193      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
    215 194              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
    216 195    }
    217 
    218 196    private bool RowIsTestForModel(int currentRow, IRegressionModel model) {
    219 197      return testPartitions == null || !testPartitions.ContainsKey(model) ||
     
    224 202      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
    225 203      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
    226       var valuesEnumerator = (from xs in GetEstimatedValueVectors(ProblemData.Dataset, rowsToEvaluate)
    227                               select AggregateEstimatedValues(xs))
    228                              .GetEnumerator();
     204      var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator();
    229 205
    230 206      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
     
    235 211    }
    236 212
    237     public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IDataset dataset, IEnumerable<int> rows) {
    238       if (!Model.Models.Any()) yield break;
    239       var estimatedValuesEnumerators = (from model in Model.Models
    240                                         select model.GetEstimatedValues(dataset, rows).GetEnumerator())
    241                                        .ToList();
    242 
    243       while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
    244         yield return from enumerator in estimatedValuesEnumerators
    245                      select enumerator.Current;
    246       }
    247     }
    248 
    249     private double AggregateEstimatedValues(IEnumerable<double> estimatedValues) {
    250       return estimatedValues.DefaultIfEmpty(double.NaN).Average();
     213    public IEnumerable<IEnumerable<double>> GetEstimatedValueVectors(IEnumerable<int> rows) {
     214      return Model.GetEstimatedValueVectors(ProblemData.Dataset, rows);
    251 215    }
    252 216    #endregion
     
    282 246    }
    283 247
    284     public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    285       regressionSolutions.AddRange(solutions);
     248    private void Model_Changed(object sender, EventArgs e) {
     249      var modelSet = new HashSet<IRegressionModel>(Model.Models);
     250      foreach (var model in Model.Models) {
     251        if (!trainingPartitions.ContainsKey(model)) trainingPartitions.Add(model, ProblemData.TrainingPartition);
     252        if (!testPartitions.ContainsKey(model)) testPartitions.Add(model, ProblemData.TrainingPartition);
     253      }
     254      foreach (var model in trainingPartitions.Keys) {
     255        if (modelSet.Contains(model)) continue;
     256        trainingPartitions.Remove(model);
     257        testPartitions.Remove(model);
     258      }
    286 259
    287 260      trainingEvaluationCache.Clear();
    288 261      testEvaluationCache.Clear();
    289 262      evaluationCache.Clear();
     263
     264      OnModelChanged();
     265    }
     266
     267    public void AddRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
     268      regressionSolutions.AddRange(solutions);
    290 269    }
    291 270    public void RemoveRegressionSolutions(IEnumerable<IRegressionSolution> solutions) {
    292 271      regressionSolutions.RemoveRange(solutions);
    293 
    294       trainingEvaluationCache.Clear();
    295       testEvaluationCache.Clear();
    296       evaluationCache.Clear();
    297 272    }
    298 273
    299 274    private void regressionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    300       foreach (var solution in e.Items) AddRegressionSolution(solution);
    301       RecalculateResults();
     275      foreach (var solution in e.Items) {
     276        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     277        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     278      }
     279      Model.AddRange(e.Items.Select(s => s.Model));
    302 280    }
    303 281    private void regressionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    304       foreach (var solution in e.Items) RemoveRegressionSolution(solution);
    305       RecalculateResults();
     282      foreach (var solution in e.Items) {
     283        trainingPartitions.Remove(solution.Model);
     284        testPartitions.Remove(solution.Model);
     285      }
     286      Model.RemoveRange(e.Items.Select(s => s.Model));
    306 287    }
    307 288    private void regressionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IRegressionSolution> e) {
    308       foreach (var solution in e.OldItems) RemoveRegressionSolution(solution);
    309       foreach (var solution in e.Items) AddRegressionSolution(solution);
    310       RecalculateResults();
    311     }
    312 
    313     private void AddRegressionSolution(IRegressionSolution solution) {
    314       if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
    315       Model.Add(solution.Model);
    316       trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
    317       testPartitions[solution.Model] = solution.ProblemData.TestPartition;
    318 
    319       trainingEvaluationCache.Clear();
    320       testEvaluationCache.Clear();
    321       evaluationCache.Clear();
    322     }
    323 
    324     private void RemoveRegressionSolution(IRegressionSolution solution) {
    325       if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
    326       Model.Remove(solution.Model);
    327       trainingPartitions.Remove(solution.Model);
    328       testPartitions.Remove(solution.Model);
    329 
    330       trainingEvaluationCache.Clear();
    331       testEvaluationCache.Clear();
    332       evaluationCache.Clear();
     289      foreach (var solution in e.OldItems) {
     290        trainingPartitions.Remove(solution.Model);
     291        testPartitions.Remove(solution.Model);
     292      }
     293      Model.RemoveRange(e.OldItems.Select(s => s.Model));
     294
     295      foreach (var solution in e.Items) {
     296        trainingPartitions.Add(solution.Model, solution.ProblemData.TrainingPartition);
     297        testPartitions.Add(solution.Model, solution.ProblemData.TestPartition);
     298      }
     299      Model.AddRange(e.Items.Select(s => s.Model));
    333 300    }
    334 301  }
Note: See TracChangeset for help on using the changeset viewer.