Free cookie consent management tool by TermsFeed Policy Generator

source: branches/LearningClassifierSystems/HeuristicLab.Encodings.ConditionActionEncoding/3.3/Ensemble/ConditionActionEnsembleSolution.cs @ 14821

Last change on this file since 14821 was 9411, checked in by sforsten, 12 years ago

#1980:

  • added multiple discretizers to GAssist
  • created ensembles for LCS problems and edited CrossValidation to use them
File size: 19.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using HeuristicLab.Collections;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
namespace HeuristicLab.Encodings.ConditionActionEncoding {
  /// <summary>
  /// Ensemble of <see cref="IConditionActionSolution"/>s. The ensemble action for a row is the action
  /// returned by the largest number of member models, where actions are grouped with
  /// <c>ProblemData.ClassifierComparer</c>. For every member model the training and test partitions
  /// it was created with are remembered, so that training/test predictions only use the models for
  /// which the row lies in the respective partition.
  /// </summary>
  [StorableClass]
  [Item("ConditionActionEnsembleSolution", "")]
  [Creatable("Data Analysis - Ensembles")]
  public class ConditionActionEnsembleSolution : ResultCollection, IConditionActionEnsembleSolution {
    // Per-row caches of aggregated ensemble actions; cleared whenever members or problem data change.
    private readonly Dictionary<int, IAction> trainingEvaluationCache = new Dictionary<int, IAction>();
    private readonly Dictionary<int, IAction> testEvaluationCache = new Dictionary<int, IAction>();
    private readonly Dictionary<int, IAction> evaluationCache = new Dictionary<int, IAction>();

    private const string ModelResultName = "Model";
    private const string ProblemDataResultName = "ProblemData";
    private const string TrainingAccuracyResultName = "Accuracy (training)";
    private const string TestAccuracyResultName = "Accuracy (test)";

    public string Filename { get; set; }

    public static new Image StaticItemImage {
      get { return HeuristicLab.Common.Resources.VSImageLibrary.Function; }
    }

    /// <summary>
    /// Share of training rows whose ensemble action matches the original action
    /// (value of the "Accuracy (training)" result).
    /// </summary>
    public double TrainingAccuracy {
      get { return ((PercentValue)this[TrainingAccuracyResultName].Value).Value; }
      private set { ((PercentValue)this[TrainingAccuracyResultName].Value).Value = value; }
    }
    /// <summary>
    /// Share of test rows whose ensemble action matches the original action
    /// (value of the "Accuracy (test)" result).
    /// </summary>
    public double TestAccuracy {
      get { return ((PercentValue)this[TestAccuracyResultName].Value).Value; }
      private set { ((PercentValue)this[TestAccuracyResultName].Value).Value = value; }
    }

    #region properties
    /// <summary>
    /// The ensemble model stored in the "Model" result. Assigning a different non-null model copies the
    /// classifier comparer of <see cref="ProblemData"/> (or null) into it, stores it and calls
    /// <see cref="OnModelChanged"/>; assigning <c>null</c> is ignored.
    /// </summary>
    public IConditionActionEnsembleModel Model {
      get { return (IConditionActionEnsembleModel)this[ModelResultName].Value; }
      protected set {
        if (this[ModelResultName].Value != value) {
          if (value != null) {
            value.ClassifierComparer = ProblemData != null ? ProblemData.ClassifierComparer : null;
            this[ModelResultName].Value = value;
            OnModelChanged();
          }
        }
      }
    }

    /// <summary>
    /// The ensemble problem data stored in the "ProblemData" result. Assigning a different non-null
    /// value moves the <c>Changed</c> subscription from the previous problem data to the new one and
    /// calls <see cref="OnProblemDataChanged"/>; assigning <c>null</c> is ignored.
    /// </summary>
    public IConditionActionEnsembleProblemData ProblemData {
      get { return (IConditionActionEnsembleProblemData)this[ProblemDataResultName].Value; }
      set {
        if (this[ProblemDataResultName].Value != value) {
          if (value != null) {
            // The previous value may be null; only unsubscribe when there is one.
            var previousProblemData = ProblemData;
            if (previousProblemData != null) previousProblemData.Changed -= new EventHandler(ProblemData_Changed);
            this[ProblemDataResultName].Value = value;
            ProblemData.Changed += new EventHandler(ProblemData_Changed);
            OnProblemDataChanged();
          }
        }
      }
    }

    private void ProblemData_Changed(object sender, EventArgs e) {
      OnProblemDataChanged();
    }
    #endregion

    // Member solutions; not persisted, rebuilt from Model.Models in AfterDeserialization.
    private readonly ItemCollection<IConditionActionSolution> conditionActionSolutions;
    /// <summary>The member solutions of the ensemble. Adding/removing items updates <see cref="Model"/>.</summary>
    public IItemCollection<IConditionActionSolution> ConditionActionSolutions {
      get { return conditionActionSolutions; }
    }

    // Training and test partition of every member model; consulted by RowIsTrainingForModel / RowIsTestForModel.
    [Storable]
    private Dictionary<IConditionActionModel, IntRange> trainingPartitions;
    [Storable]
    private Dictionary<IConditionActionModel, IntRange> testPartitions;

    /// <summary>
    /// Rebuilds the member solutions from the deserialized models and their stored partitions.
    /// The solutions are added before the collection handlers are registered, so the models are not
    /// added to <see cref="Model"/> a second time.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      Model.ClassifierComparer = ProblemData.ClassifierComparer;
      foreach (var model in Model.Models) {
        IConditionActionProblemData problemData = ProblemData.GetConditionActionProblemData();
        problemData.TrainingPartition.Start = trainingPartitions[model].Start;
        problemData.TrainingPartition.End = trainingPartitions[model].End;
        problemData.TestPartition.Start = testPartitions[model].Start;
        problemData.TestPartition.End = testPartitions[model].End;

        conditionActionSolutions.Add(model.CreateConditionActionSolution(problemData));
      }
      RegisterConditionActionSolutionsEventHandler();
    }

    [StorableConstructor]
    protected ConditionActionEnsembleSolution(bool deserializing)
      : base(deserializing) {
      conditionActionSolutions = new ItemCollection<IConditionActionSolution>();
    }
    protected ConditionActionEnsembleSolution(ConditionActionEnsembleSolution original, Cloner cloner)
      : base(original, cloner) {
      trainingPartitions = new Dictionary<IConditionActionModel, IntRange>();
      testPartitions = new Dictionary<IConditionActionModel, IntRange>();
      foreach (var pair in original.trainingPartitions) {
        trainingPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
      }
      foreach (var pair in original.testPartitions) {
        testPartitions[cloner.Clone(pair.Key)] = cloner.Clone(pair.Value);
      }

      trainingEvaluationCache = new Dictionary<int, IAction>(original.ProblemData.TrainingIndices.Count());
      testEvaluationCache = new Dictionary<int, IAction>(original.ProblemData.TestIndices.Count());

      conditionActionSolutions = cloner.Clone(original.conditionActionSolutions);
      RegisterConditionActionSolutionsEventHandler();
    }

    /// <summary>
    /// Creates an ensemble in which every model uses (a clone of) the training and test partition
    /// of <paramref name="problemData"/>.
    /// </summary>
    public ConditionActionEnsembleSolution(IEnumerable<IConditionActionModel> models, IConditionActionProblemData problemData)
      : this(models, problemData,
             models.Select(m => (IntRange)problemData.TrainingPartition.Clone()),
             models.Select(m => (IntRange)problemData.TestPartition.Clone())
      ) { }

    /// <summary>
    /// Creates an empty ensemble without any results.
    /// NOTE(review): unlike the other public constructors this one registers no Model/ProblemData/accuracy
    /// results, so the properties reading them (and adding solutions, which recalculates the results)
    /// presumably fail until results are added elsewhere - confirm this is intended.
    /// </summary>
    public ConditionActionEnsembleSolution()
      : base() {
      trainingPartitions = new Dictionary<IConditionActionModel, IntRange>();
      testPartitions = new Dictionary<IConditionActionModel, IntRange>();
      conditionActionSolutions = new ItemCollection<IConditionActionSolution>();

      RegisterConditionActionSolutionsEventHandler();
    }

    /// <summary>Creates an ensemble without member models for <paramref name="problemData"/>.</summary>
    public ConditionActionEnsembleSolution(IConditionActionProblemData problemData)
      : this(Enumerable.Empty<IConditionActionModel>(), problemData) {
    }

    /// <summary>
    /// Creates an ensemble whose i-th model is evaluated on the i-th training and i-th test partition.
    /// </summary>
    /// <param name="models">Member models.</param>
    /// <param name="problemData">Problem data; it is cloned for the ensemble and for every member solution.</param>
    /// <param name="trainingPartitions">One training partition per model.</param>
    /// <param name="testPartitions">One test partition per model.</param>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentException">The three sequences do not have the same length.</exception>
    public ConditionActionEnsembleSolution(IEnumerable<IConditionActionModel> models, IConditionActionProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
      : base() {
      if (models == null) throw new ArgumentNullException("models");
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (trainingPartitions == null) throw new ArgumentNullException("trainingPartitions");
      if (testPartitions == null) throw new ArgumentNullException("testPartitions");

      Add(new Result(ModelResultName, "The data analysis model.", new ConditionActionEnsembleModel(Enumerable.Empty<IConditionActionModel>())));
      Add(new Result(ProblemDataResultName, "The data analysis problem data.", new ConditionActionEnsembleProblemData((IConditionActionProblemData)problemData.Clone())));
      Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
      Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));

      // NOTE(review): ProblemData.Changed is only subscribed by the ProblemData setter, not here, so
      // in-place changes to the problem data created above do not call OnProblemDataChanged - confirm intended.
      Model.ClassifierComparer = ProblemData.ClassifierComparer;

      this.trainingPartitions = new Dictionary<IConditionActionModel, IntRange>();
      this.testPartitions = new Dictionary<IConditionActionModel, IntRange>();
      this.conditionActionSolutions = new ItemCollection<IConditionActionSolution>();

      var solutions = new List<IConditionActionSolution>();
      using (var modelEnumerator = models.GetEnumerator())
      using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
      using (var testPartitionEnumerator = testPartitions.GetEnumerator()) {
        while (true) {
          // Advance all three every step; the sequences match only if they run out at the same step.
          bool hasModel = modelEnumerator.MoveNext();
          bool hasTrainingPartition = trainingPartitionEnumerator.MoveNext();
          bool hasTestPartition = testPartitionEnumerator.MoveNext();
          if (!hasModel && !hasTrainingPartition && !hasTestPartition) break;
          if (!(hasModel && hasTrainingPartition && hasTestPartition))
            throw new ArgumentException("The number of models, training partitions and test partitions must be equal.");

          var p = (IConditionActionProblemData)problemData.Clone();
          p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
          p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
          p.TestPartition.Start = testPartitionEnumerator.Current.Start;
          p.TestPartition.End = testPartitionEnumerator.Current.End;

          solutions.Add(modelEnumerator.Current.CreateConditionActionSolution(p));
        }
      }

      trainingEvaluationCache = new Dictionary<int, IAction>(problemData.TrainingIndices.Count());
      testEvaluationCache = new Dictionary<int, IAction>(problemData.TestIndices.Count());

      // Register first so that adding the solutions also adds their models to the ensemble model.
      RegisterConditionActionSolutionsEventHandler();
      conditionActionSolutions.AddRange(solutions);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConditionActionEnsembleSolution(this, cloner);
    }

    private void RegisterConditionActionSolutionsEventHandler() {
      conditionActionSolutions.ItemsAdded += new CollectionItemsChangedEventHandler<IConditionActionSolution>(conditionActionSolutions_ItemsAdded);
      conditionActionSolutions.ItemsRemoved += new CollectionItemsChangedEventHandler<IConditionActionSolution>(conditionActionSolutions_ItemsRemoved);
      conditionActionSolutions.CollectionReset += new CollectionItemsChangedEventHandler<IConditionActionSolution>(conditionActionSolutions_CollectionReset);
    }

    #region Evaluation
    /// <summary>Ensemble actions for all dataset rows (0 to Dataset.Rows - 1), voted by all member models.</summary>
    public IEnumerable<IAction> EstimatedAction {
      get { return GetEstimatedNiches(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    }

    /// <summary>
    /// Ensemble actions for the training rows. For a row, only models for which the row is in their
    /// training partition and not in their test partition vote. Results are cached per row.
    /// </summary>
    public IEnumerable<IAction> EstimatedTrainingAction {
      get {
        var rows = ProblemData.TrainingIndices;
        var rowsToEvaluate = rows.Except(trainingEvaluationCache.Keys).ToList();
        FillCache(trainingEvaluationCache, rowsToEvaluate,
                  GetEstimatedValues(rowsToEvaluate, (r, m) => RowIsTrainingForModel(r, m) && !RowIsTestForModel(r, m)));
        return rows.Select(row => trainingEvaluationCache[row]);
      }
    }

    /// <summary>
    /// Ensemble actions for the test rows. For a row, only models for which the row is in their test
    /// partition vote. Results are cached per row.
    /// </summary>
    public IEnumerable<IAction> EstimatedTestAction {
      get {
        var rows = ProblemData.TestIndices;
        var rowsToEvaluate = rows.Except(testEvaluationCache.Keys).ToList();
        FillCache(testEvaluationCache, rowsToEvaluate, GetEstimatedValues(rowsToEvaluate, RowIsTestForModel));
        return rows.Select(row => testEvaluationCache[row]);
      }
    }

    /// <summary>
    /// Ensemble actions for <paramref name="rows"/>, voted by all member models. Only rows that are
    /// not cached yet are evaluated.
    /// </summary>
    public IEnumerable<IAction> GetEstimatedNiches(IEnumerable<int> rows) {
      var rowsToEvaluate = rows.Except(evaluationCache.Keys).ToList();
      if (rowsToEvaluate.Count > 0) {
        // Evaluate exactly the rows that are cached below, so rows and values stay aligned.
        var aggregatedActions = GetEstimatedNicheVectors(ProblemData.FetchInput(rowsToEvaluate))
                                  .Select(xs => AggregateEstimatedClassValues(xs));
        FillCache(evaluationCache, rowsToEvaluate, aggregatedActions);
      }
      return rows.Select(row => evaluationCache[row]);
    }

    /// <summary>
    /// For every input, the actions of all member models (one entry per model). Yields nothing when the
    /// ensemble has no models; stops as soon as any model runs out of actions.
    /// </summary>
    public IEnumerable<IEnumerable<IAction>> GetEstimatedNicheVectors(IEnumerable<IInput> input) {
      if (!Model.Models.Any()) yield break;
      var actionEnumerators = Model.Models.Select(model => model.GetAction(input).GetEnumerator()).ToList();
      try {
        while (actionEnumerators.All(en => en.MoveNext())) {
          // Materialize: a lazy projection over Current would change once the enumerators advance.
          yield return actionEnumerators.Select(en => en.Current).ToArray();
        }
      } finally {
        foreach (var en in actionEnumerators) en.Dispose();
      }
    }

    /// <summary>
    /// Aggregated action per row, where for each row only the models accepted by
    /// <paramref name="modelSelectionPredicate"/> vote.
    /// </summary>
    private IEnumerable<IAction> GetEstimatedValues(IEnumerable<int> rows, Func<int, IConditionActionModel, bool> modelSelectionPredicate) {
      var input = ProblemData.FetchInput(rows);
      var modelEnumerators = (from model in Model.Models
                              select new { Model = model, EstimatedValuesEnumerator = model.GetAction(input).GetEnumerator() })
                             .ToList();
      try {
        using (var rowsEnumerator = rows.GetEnumerator()) {
          // aggregate to make sure that MoveNext is called for all enumerators
          while (rowsEnumerator.MoveNext() & modelEnumerators.Select(en => en.EstimatedValuesEnumerator.MoveNext()).Aggregate(true, (acc, b) => acc & b)) {
            int currentRow = rowsEnumerator.Current;

            var selectedEnumerators = from pair in modelEnumerators
                                      where modelSelectionPredicate(currentRow, pair.Model)
                                      select pair.EstimatedValuesEnumerator;

            yield return AggregateEstimatedClassValues(selectedEnumerators.Select(x => x.Current));
          }
        }
      } finally {
        foreach (var pair in modelEnumerators) pair.EstimatedValuesEnumerator.Dispose();
      }
    }

    /// <summary>
    /// Stores the i-th value of <paramref name="values"/> for the i-th row of <paramref name="rows"/>;
    /// stops when either sequence runs out. Does not start <paramref name="values"/> when there are no rows.
    /// </summary>
    private static void FillCache(Dictionary<int, IAction> cache, IList<int> rows, IEnumerable<IAction> values) {
      if (rows.Count == 0) return;
      using (var rowEnumerator = rows.GetEnumerator())
      using (var valueEnumerator = values.GetEnumerator()) {
        // '&' instead of '&&': both enumerators are advanced on every step.
        while (rowEnumerator.MoveNext() & valueEnumerator.MoveNext()) {
          cache.Add(rowEnumerator.Current, valueEnumerator.Current);
        }
      }
    }

    /// <summary>
    /// Majority vote: the action that occurs most often (grouped with ProblemData.ClassifierComparer).
    /// On ties the group whose first member appears earliest wins; returns null for no votes.
    /// </summary>
    private IAction AggregateEstimatedClassValues(IEnumerable<IAction> estimatedNiches) {
      return estimatedNiches
      .GroupBy(x => x, ProblemData.ClassifierComparer)
      .OrderByDescending(g => g.Count())
      .Select(g => g.Key)
      .FirstOrDefault();
    }

    /// <summary>Recomputes the training and test accuracy results.</summary>
    private void RecalculateResults() {
      var estimatedTraining = EstimatedTrainingAction;
      var estimatedTest = EstimatedTestAction;

      var originalTrainingAction = ProblemData.FetchAction(ProblemData.TrainingIndices);
      var originalTestAction = ProblemData.FetchAction(ProblemData.TestIndices);

      TrainingAccuracy = CalculateAccuracy(originalTrainingAction, estimatedTraining);
      TestAccuracy = CalculateAccuracy(originalTestAction, estimatedTest);
    }

    /// <summary>
    /// Fraction of elements of <paramref name="original"/> that match (via <c>IAction.Match</c>) the
    /// element at the same position of <paramref name="estimated"/>. The denominator is the number of
    /// elements in <paramref name="original"/>; missing estimates count as wrong. Returns NaN for an
    /// empty <paramref name="original"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    public static double CalculateAccuracy(IEnumerable<IAction> original, IEnumerable<IAction> estimated) {
      if (original == null) throw new ArgumentNullException("original");
      if (estimated == null) throw new ArgumentNullException("estimated");

      int rows = 0;
      int correctClassified = 0;
      using (var originalEnumerator = original.GetEnumerator())
      using (var estimatedActionEnumerator = estimated.GetEnumerator()) {
        bool estimatedAvailable = true;
        // Single pass over 'original': it is counted and compared at the same time.
        while (originalEnumerator.MoveNext()) {
          rows++;
          if (estimatedAvailable) {
            estimatedAvailable = estimatedActionEnumerator.MoveNext();
            if (estimatedAvailable && originalEnumerator.Current.Match(estimatedActionEnumerator.Current)) {
              correctClassified++;
            }
          }
        }
      }

      return (double)correctClassified / rows;
    }

    /// <summary>
    /// True when no partition is known for <paramref name="model"/> or the row lies in its half-open
    /// training partition [Start, End).
    /// </summary>
    private bool RowIsTrainingForModel(int currentRow, IConditionActionModel model) {
      IntRange partition;
      return trainingPartitions == null || !trainingPartitions.TryGetValue(model, out partition) ||
              (partition.Start <= currentRow && currentRow < partition.End);
    }

    /// <summary>
    /// True when no partition is known for <paramref name="model"/> or the row lies in its half-open
    /// test partition [Start, End).
    /// </summary>
    private bool RowIsTestForModel(int currentRow, IConditionActionModel model) {
      IntRange partition;
      return testPartitions == null || !testPartitions.TryGetValue(model, out partition) ||
              (partition.Start <= currentRow && currentRow < partition.End);
    }
    #endregion

    private void ClearEvaluationCaches() {
      trainingEvaluationCache.Clear();
      testEvaluationCache.Clear();
      evaluationCache.Clear();
    }

    public event EventHandler ProblemDataChanged;
    /// <summary>
    /// Clears the caches and pushes the ensemble's partitions to the members. Non-ensemble members share
    /// one problem data obtained from GetConditionActionProblemData; nested ensembles receive the ensemble
    /// problem data. Every stored partition is reset to the ensemble's, then results are recalculated and
    /// <see cref="ProblemDataChanged"/> is raised.
    /// </summary>
    protected void OnProblemDataChanged() {
      ClearEvaluationCaches();

      IConditionActionProblemData problemData = ProblemData.GetConditionActionProblemData();

      problemData.TrainingPartition.Start = ProblemData.TrainingPartition.Start;
      problemData.TrainingPartition.End = ProblemData.TrainingPartition.End;
      problemData.TestPartition.Start = ProblemData.TestPartition.Start;
      problemData.TestPartition.End = ProblemData.TestPartition.End;

      foreach (var solution in conditionActionSolutions) {
        if (solution is ConditionActionEnsembleSolution)
          solution.ProblemData = ProblemData;
        else
          solution.ProblemData = problemData;
      }
      foreach (var trainingPartition in trainingPartitions.Values) {
        trainingPartition.Start = ProblemData.TrainingPartition.Start;
        trainingPartition.End = ProblemData.TrainingPartition.End;
      }
      foreach (var testPartition in testPartitions.Values) {
        testPartition.Start = ProblemData.TestPartition.Start;
        testPartition.End = ProblemData.TestPartition.End;
      }

      RecalculateResults();
      var listeners = ProblemDataChanged;
      if (listeners != null) listeners(this, EventArgs.Empty);
    }

    public event EventHandler ModelChanged;
    /// <summary>Recalculates the results and raises <see cref="ModelChanged"/>.</summary>
    protected virtual void OnModelChanged() {
      RecalculateResults();
      var listeners = ModelChanged;
      if (listeners != null) listeners(this, EventArgs.Empty);
    }

    /// <summary>Adds member solutions (their models join the ensemble model) and clears the caches.</summary>
    public void AddConditionActionSolutions(IEnumerable<IConditionActionSolution> solutions) {
      conditionActionSolutions.AddRange(solutions);
      ClearEvaluationCaches();
    }
    /// <summary>Removes member solutions (their models leave the ensemble model) and clears the caches.</summary>
    public void RemoveConditionActionSolutions(IEnumerable<IConditionActionSolution> solutions) {
      conditionActionSolutions.RemoveRange(solutions);
      ClearEvaluationCaches();
    }

    private void conditionActionSolutions_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IConditionActionSolution> e) {
      foreach (var solution in e.Items) AddConditionActionSolution(solution);
      RecalculateResults();
    }
    private void conditionActionSolutions_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IConditionActionSolution> e) {
      foreach (var solution in e.Items) RemoveConditionActionSolution(solution);
      RecalculateResults();
    }
    private void conditionActionSolutions_CollectionReset(object sender, CollectionItemsChangedEventArgs<IConditionActionSolution> e) {
      foreach (var solution in e.OldItems) RemoveConditionActionSolution(solution);
      foreach (var solution in e.Items) AddConditionActionSolution(solution);
      RecalculateResults();
    }

    /// <summary>Adds the solution's model and remembers the solution's partitions for it.</summary>
    /// <exception cref="ArgumentException">The model is already part of the ensemble.</exception>
    private void AddConditionActionSolution(IConditionActionSolution solution) {
      if (Model.Models.Contains(solution.Model))
        throw new ArgumentException("The model of the solution is already part of the ensemble.", "solution");
      Model.Add(solution.Model);
      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
      testPartitions[solution.Model] = solution.ProblemData.TestPartition;

      ClearEvaluationCaches();
    }

    /// <summary>Removes the solution's model and its stored partitions.</summary>
    /// <exception cref="ArgumentException">The model is not part of the ensemble.</exception>
    private void RemoveConditionActionSolution(IConditionActionSolution solution) {
      if (!Model.Models.Contains(solution.Model))
        throw new ArgumentException("The model of the solution is not part of the ensemble.", "solution");
      Model.Remove(solution.Model);
      trainingPartitions.Remove(solution.Model);
      testPartitions.Remove(solution.Model);

      ClearEvaluationCaches();
    }

    #region IConditionActionSolution Members
    IConditionActionModel IConditionActionSolution.Model {
      get { return Model; }
    }
    IConditionActionProblemData IConditionActionSolution.ProblemData {
      get { return ProblemData; }
      // Wraps the given problem data in a new ensemble problem data.
      set { ProblemData = new ConditionActionEnsembleProblemData(value); }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.