Free cookie consent management tool by TermsFeed Policy Generator

source: branches/LearningClassifierSystems/HeuristicLab.Optimization.Operators.LCS/3.3/GAssist/Ensemble/GAssistEnsembleSolution.cs @ 17578

Last change on this file since 17578 was 9411, checked in by sforsten, 12 years ago

#1980:

  • added multiple discretizer to GAssist
  • created ensembles for LCS problems and edited CrossValidation to use them
File size: 19.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using HeuristicLab.Collections;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Optimization.Operators.LCS {
33  [StorableClass]
34  [Item("GAssistEnsembleSolution", "Represents a GAssist ensemble.")]
35  [Creatable("Data Analysis - Ensembles")]
36  public class GAssistEnsembleSolution : ResultCollection, IGAssistEnsembleSolution {
37    private readonly Dictionary<int, IGAssistNiche> trainingEvaluationCache = new Dictionary<int, IGAssistNiche>();
38    private readonly Dictionary<int, IGAssistNiche> testEvaluationCache = new Dictionary<int, IGAssistNiche>();
39    private readonly Dictionary<int, IGAssistNiche> evaluationCache = new Dictionary<int, IGAssistNiche>();
40
41    private const string ModelResultName = "Model";
42    private const string ProblemDataResultName = "ProblemData";
43    private const string TrainingAccuracyResultName = "Accuracy (training)";
44    private const string TestAccuracyResultName = "Accuracy (test)";
45
46    public string Filename { get; set; }
47
48    public static new Image StaticItemImage {
49      get { return HeuristicLab.Common.Resources.VSImageLibrary.Function; }
50    }
51
52    public double TrainingAccuracy {
53      get { return ((PercentValue)this[TrainingAccuracyResultName].Value).Value; }
54      private set { ((PercentValue)this[TrainingAccuracyResultName].Value).Value = value; }
55    }
56    public double TestAccuracy {
57      get { return ((PercentValue)this[TestAccuracyResultName].Value).Value; }
58      private set { ((PercentValue)this[TestAccuracyResultName].Value).Value = value; }
59    }
60
61    #region properties
62    public IGAssistEnsembleModel Model {
63      get { return (IGAssistEnsembleModel)this[ModelResultName].Value; }
64      protected set {
65        if (this[ModelResultName].Value != value) {
66          if (value != null) {
67            this[ModelResultName].Value = value;
68            OnModelChanged();
69          }
70        }
71      }
72    }
73
74    public IGAssistEnsembleProblemData ProblemData {
75      get { return (IGAssistEnsembleProblemData)this[ProblemDataResultName].Value; }
76      set {
77        if (this[ProblemDataResultName].Value != value) {
78          if (value != null) {
79            ProblemData.Changed -= new EventHandler(ProblemData_Changed);
80            this[ProblemDataResultName].Value = value;
81            ProblemData.Changed += new EventHandler(ProblemData_Changed);
82            OnProblemDataChanged();
83          }
84        }
85      }
86    }
87
88    private void ProblemData_Changed(object sender, EventArgs e) {
89      OnProblemDataChanged();
90    }
91    #endregion
92
93    private readonly ItemCollection<IGAssistSolution> gassistSolutions;
94    public IItemCollection<IGAssistSolution> GAssistSolutions {
95      get { return gassistSolutions; }
96    }
97
98    [Storable]
99    private Dictionary<IGAssistModel, IntRange> trainingPartitions;
100    [Storable]
101    private Dictionary<IGAssistModel, IntRange> testPartitions;
102
103    [StorableHook(HookType.AfterDeserialization)]
104    private void AfterDeserialization() {
105      foreach (var model in Model.Models) {
106        IGAssistProblemData problemData = ProblemData.GetGAssistProblemData();
107        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
108        problemData.TrainingPartition.End = trainingPartitions[model].End;
109        problemData.TestPartition.Start = testPartitions[model].Start;
110        problemData.TestPartition.End = testPartitions[model].End;
111
112        gassistSolutions.Add(model.CreateGAssistSolution(problemData));
113      }
114      RegisterGAssistSolutionsEventHandler();
115    }
116
117    [StorableConstructor]
118    protected GAssistEnsembleSolution(bool deserializing)
119      : base(deserializing) {
120      gassistSolutions = new ItemCollection<IGAssistSolution>();
121    }
122    protected GAssistEnsembleSolution(GAssistEnsembleSolution original, Cloner cloner)
123      : base(original, cloner) {
124      trainingPartitions = new Dictionary<IGAssistModel, IntRange>();
125      testPartitions = new Dictionary<IGAssistModel, IntRange>();
126      foreach (var pair in original.trainingPartitions) {
127        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
128      }
129      foreach (var pair in original.testPartitions) {
130        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
131      }
132
133      trainingEvaluationCache = new Dictionary<int, IGAssistNiche>(original.ProblemData.TrainingIndices.Count());
134      testEvaluationCache = new Dictionary<int, IGAssistNiche>(original.ProblemData.TestIndices.Count());
135
136      gassistSolutions = cloner.Clone(original.gassistSolutions);
137      RegisterGAssistSolutionsEventHandler();
138    }
139
140    public GAssistEnsembleSolution(IEnumerable<IGAssistModel> models, IGAssistProblemData problemData)
141      : this(models, problemData,
142             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
143             models.Select(m => (IntRange)problemData.TestPartition.Clone())
144      ) { }
145
146    public GAssistEnsembleSolution()
147      : base() {
148      trainingPartitions = new Dictionary<IGAssistModel, IntRange>();
149      testPartitions = new Dictionary<IGAssistModel, IntRange>();
150      gassistSolutions = new ItemCollection<IGAssistSolution>();
151
152      RegisterGAssistSolutionsEventHandler();
153    }
154
155    public GAssistEnsembleSolution(IGAssistProblemData problemData)
156      : this(Enumerable.Empty<IGAssistModel>(), problemData) {
157    }
158
159    public GAssistEnsembleSolution(IEnumerable<IGAssistModel> models, IGAssistProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
160      : base() {
161      Add(new Result(ModelResultName, "The data analysis model.", new GAssistEnsembleModel(Enumerable.Empty<IGAssistModel>())));
162      Add(new Result(ProblemDataResultName, "The data analysis problem data.", new GAssistEnsembleProblemData((IGAssistProblemData)problemData.Clone())));
163      Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
164      Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
165
166      this.trainingPartitions = new Dictionary<IGAssistModel, IntRange>();
167      this.testPartitions = new Dictionary<IGAssistModel, IntRange>();
168      this.gassistSolutions = new ItemCollection<IGAssistSolution>();
169
170      List<IGAssistSolution> solutions = new List<IGAssistSolution>();
171      var modelEnumerator = models.GetEnumerator();
172      var trainingPartitionEnumerator = trainingPartitions.GetEnumerator();
173      var testPartitionEnumerator = testPartitions.GetEnumerator();
174
175      while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
176        var p = (IGAssistProblemData)problemData.Clone();
177        p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
178        p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
179        p.TestPartition.Start = testPartitionEnumerator.Current.Start;
180        p.TestPartition.End = testPartitionEnumerator.Current.End;
181
182        solutions.Add(modelEnumerator.Current.CreateGAssistSolution(p));
183      }
184      if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
185        throw new ArgumentException();
186      }
187
188      trainingEvaluationCache = new Dictionary<int, IGAssistNiche>(problemData.TrainingIndices.Count());
189      testEvaluationCache = new Dictionary<int, IGAssistNiche>(problemData.TestIndices.Count());
190
191      RegisterGAssistSolutionsEventHandler();
192      gassistSolutions.AddRange(solutions);
193    }
194
195    public override IDeepCloneable Clone(Cloner cloner) {
196      return new GAssistEnsembleSolution(this, cloner);
197    }
198
199    private void RegisterGAssistSolutionsEventHandler() {
200      gassistSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IGAssistSolution>(gassistSolutions_ItemsAdded);
201      gassistSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IGAssistSolution>(gassistSolutions_ItemsRemoved);
202      gassistSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IGAssistSolution>(gassistSolutions_CollectionReset);
203    }
204
205    #region Evaluation
206    public IEnumerable<IGAssistNiche> EstimatedNiches {
207      get { return GetEstimatedNiches(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
208    }
209
210    public IEnumerable<IGAssistNiche> EstimatedTrainingNiche {
211      get {
212        var rows = ProblemData.TrainingIndices;
213        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys);
214        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
215        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)).GetEnumerator();
216
217        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
218          trainingEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
219        }
220
221        return rows.Select(row => trainingEvaluationCache[row]);
222      }
223    }
224
225    public IEnumerable<IGAssistNiche> EstimatedTestNiche {
226      get {
227        var rows = ProblemData.TestIndices;
228        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys);
229        var rowsEnumerator = rowsToEvaluate.GetEnumerator();
230        var valuesEnumerator = GetEstimatedValues(rowsToEvaluate, RowIsTestForModel).GetEnumerator();
231
232        while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
233          testEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
234        }
235
236        return rows.Select(row => testEvaluationCache[row]);
237      }
238    }
239
240    public IEnumerable<IGAssistNiche> GetEstimatedNiches(IEnumerable<int> rows) {
241      var rowsToEvaluate = rows.Except(evaluationCache.Keys);
242      var rowsEnumerator = rowsToEvaluate.GetEnumerator();
243
244      var valuesEnumerator = (from xs in GetEstimatedNicheVectors(ProblemData.FetchInput(rows))
245                              select AggregateEstimatedClassValues(xs))
246                              .GetEnumerator();
247
248      while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) {
249        evaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current);
250      }
251
252      return rows.Select(row => evaluationCache[row]);
253    }
254
255    public IEnumerable<IEnumerable<IGAssistNiche>> GetEstimatedNicheVectors(IEnumerable<IGAssistInput> input) {
256      if (!Model.Models.Any()) yield break;
257      var estimatedValuesEnumerators = (from model in Model.Models
258                                        select model.Evaluate(input).GetEnumerator())
259                                       .ToList();
260
261      while (estimatedValuesEnumerators.All(en => en.MoveNext())) {
262        yield return from enumerator in estimatedValuesEnumerators
263                     select enumerator.Current;
264      }
265    }
266
267    private IEnumerable<IGAssistNiche> GetEstimatedValues(IEnumerable<int> rows, Func<int, IGAssistModel, bool> modelSelectionPredicate) {
268      var input = ProblemData.FetchInput(rows);
269      var estimatedValuesEnumerators = (from model in Model.Models
270                                        select new { Model = model, EstimatedValuesEnumerator = model.Evaluate(input).GetEnumerator() })
271                                       .ToList();
272      var rowsEnumerator = rows.GetEnumerator();
273      // aggregate to make sure that MoveNext is called for all enumerators
274      while (rowsEnumerator.MoveNext() & estimatedValuesEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
275        int currentRow = rowsEnumerator.Current;
276
277        var selectedEnumerators = from pair in estimatedValuesEnumerators
278                                  where modelSelectionPredicate(currentRow, pair.Model)
279                                  select pair.EstimatedValuesEnumerator;
280
281        yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
282      }
283    }
284
285    private IGAssistNiche AggregateEstimatedClassValues(IEnumerable<IGAssistNiche> estimatedNiches) {
286      return estimatedNiches
287      .GroupBy(x => x, new GAssistNicheComparer())
288      .OrderByDescending(g => g.Count())
289      .Select(g => g.Key)
290      .FirstOrDefault();
291    }
292
293    private void RecalculateResults() {
294      var originalTrainingCondition = ProblemData.FetchInput(ProblemData.TrainingIndices);
295      var originalTestCondition = ProblemData.FetchInput(ProblemData.TestIndices);
296      var estimatedTraining = EstimatedTrainingNiche;
297      var estimatedTest = EstimatedTestNiche;
298
299      var originalTrainingAction = ProblemData.FetchAction(ProblemData.TrainingIndices);
300      var originalTestAction = ProblemData.FetchAction(ProblemData.TestIndices);
301
302      TrainingAccuracy = CalculateAccuracy(originalTrainingAction, estimatedTraining);
303      TestAccuracy = CalculateAccuracy(originalTestAction, estimatedTest);
304    }
305
306    public static double CalculateAccuracy(IEnumerable<IGAssistNiche> original, IEnumerable<IGAssistNiche> estimated) {
307      double correctClassified = 0;
308
309      double rows = original.Count();
310      var originalEnumerator = original.GetEnumerator();
311      var estimatedActionEnumerator = estimated.GetEnumerator();
312
313      while (originalEnumerator.MoveNext() && estimatedActionEnumerator.MoveNext()) {
314        if (originalEnumerator.Current != null && estimatedActionEnumerator.Current != null
315          && originalEnumerator.Current.SameNiche(estimatedActionEnumerator.Current)) {
316          correctClassified++;
317        }
318      }
319      return correctClassified / rows;
320    }
321
322    private bool RowIsTrainingForModel(int currentRow, IGAssistModel model) {
323      return trainingPartitions == null || !trainingPartitions.ContainsKey(model) ||
324              (trainingPartitions[model].Start <= currentRow && currentRow < trainingPartitions[model].End);
325    }
326
327    private bool RowIsTestForModel(int currentRow, IGAssistModel model) {
328      return testPartitions == null || !testPartitions.ContainsKey(model) ||
329              (testPartitions[model].Start <= currentRow && currentRow < testPartitions[model].End);
330    }
331    #endregion
332
333    public event EventHandler ProblemDataChanged;
334    protected void OnProblemDataChanged() {
335      trainingEvaluationCache.Clear();
336      testEvaluationCache.Clear();
337      evaluationCache.Clear();
338
339      IGAssistProblemData problemData = ProblemData.GetGAssistProblemData();
340
341      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
342      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
343      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
344      problemData.TestPartition.End = ProblemData.TestPartition.End;
345
346      foreach (var solution in GAssistSolutions) {
347        if (solution is GAssistEnsembleSolution)
348          solution.ProblemData = ProblemData;
349        else
350          solution.ProblemData = problemData;
351      }
352      foreach (var trainingPartition in trainingPartitions.Values) {
353        trainingPartition.Start = ProblemData.TrainingPartition.Start;
354        trainingPartition.End = ProblemData.TrainingPartition.End;
355      }
356      foreach (var testPartition in testPartitions.Values) {
357        testPartition.Start = ProblemData.TestPartition.Start;
358        testPartition.End = ProblemData.TestPartition.End;
359      }
360
361      RecalculateResults();
362      var listeners = ProblemDataChanged;
363      if (listeners != null) listeners(this, EventArgs.Empty);
364    }
365
366    public event EventHandler ModelChanged;
367    protected virtual void OnModelChanged() {
368      RecalculateResults();
369      var listeners = ModelChanged;
370      if (listeners != null) listeners(this, EventArgs.Empty);
371    }
372
373    public void AddGAssistSolutions(IEnumerable<IGAssistSolution> solutions) {
374      gassistSolutions.AddRange(solutions);
375
376      trainingEvaluationCache.Clear();
377      testEvaluationCache.Clear();
378      evaluationCache.Clear();
379    }
380    public void RemoveGAssistSolutions(IEnumerable<IGAssistSolution> solutions) {
381      gassistSolutions.RemoveRange(solutions);
382
383      trainingEvaluationCache.Clear();
384      testEvaluationCache.Clear();
385      evaluationCache.Clear();
386    }
387
388    private void gassistSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IGAssistSolution> e) {
389      foreach (var solution in e.Items) AddGAssistSolution(solution);
390      RecalculateResults();
391    }
392    private void gassistSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IGAssistSolution> e) {
393      foreach (var solution in e.Items) RemoveGAssistSolution(solution);
394      RecalculateResults();
395    }
396    private void gassistSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IGAssistSolution> e) {
397      foreach (var solution in e.OldItems) RemoveGAssistSolution(solution);
398      foreach (var solution in e.Items) AddGAssistSolution(solution);
399      RecalculateResults();
400    }
401
402    private void AddGAssistSolution(IGAssistSolution solution) {
403      if (Model.Models.Contains(solution.Model)) throw new ArgumentException();
404      Model.Add(solution.Model);
405      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
406      testPartitions[solution.Model] = solution.ProblemData.TestPartition;
407
408      trainingEvaluationCache.Clear();
409      testEvaluationCache.Clear();
410      evaluationCache.Clear();
411    }
412
413    private void RemoveGAssistSolution(IGAssistSolution solution) {
414      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException();
415      Model.Remove(solution.Model);
416      trainingPartitions.Remove(solution.Model);
417      testPartitions.Remove(solution.Model);
418
419      trainingEvaluationCache.Clear();
420      testEvaluationCache.Clear();
421      evaluationCache.Clear();
422    }
423
424    #region IGAssistSolution Members
425    IGAssistModel IGAssistSolution.Model {
426      get { return Model; }
427    }
428    IGAssistProblemData IGAssistSolution.ProblemData {
429      get { return ProblemData; }
430      set { ProblemData = new GAssistEnsembleProblemData(value); }
431    }
432    public int TrainingNumberOfAliveRules {
433      get { return gassistSolutions.Sum(x => x.TrainingNumberOfAliveRules); }
434    }
435    public double TrainingTheoryLength {
436      get { return gassistSolutions.Sum(x => x.TrainingTheoryLength); }
437    }
438    public double TrainingExceptionsLength {
439      get { return 105.0 - TrainingAccuracy * 100.0; }
440    }
441    public int Classes {
442      get { return ProblemData.Classes; }
443    }
444    #endregion
445  }
446}
Note: See TracBrowser for help on using the repository browser.