Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/22/11 14:57:24 (13 years ago)
Author:
gkronber
Message:

#1552: implemented first version of an optimizer for regression analysis experiments

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
1 edited
1 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r6583 r6587  
    107107  </ItemGroup>
    108108  <ItemGroup>
     109    <Compile Include="RegressionWorkbench.cs" />
    109110    <Compile Include="CrossValidation.cs">
    110111      <SubType>Code</SubType>
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RegressionWorkbench.cs

    r6584 r6587  
    3232using HeuristicLab.Problems.DataAnalysis;
    3333using HeuristicLab.Problems.DataAnalysis.Symbolic;
     34using HeuristicLab.Parameters;
    3435
    3536namespace HeuristicLab.Algorithms.DataAnalysis {
    36   [Item("Cross Validation", "Cross-validation wrapper for data analysis algorithms.")]
     37  [Item("Regression Workbench", "Experiment containing multiple algorithms for regression analysis.")]
    3738  [Creatable("Data Analysis")]
    3839  [StorableClass]
    39   public sealed class CrossValidation : ParameterizedNamedItem, IAlgorithm, IStorableContent {
    40     public CrossValidation()
     40  public sealed class RegressionWorkbench : ParameterizedNamedItem, IOptimizer, IStorableContent {
     41    public string Filename { get; set; }
     42
     43    private const string ProblemDataParameterName = "ProblemData";
     44
     45    #region parameter properties
     46    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
     47      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
     48    }
     49    #endregion
     50    #region properties
     51    public IRegressionProblemData ProblemData {
     52      get { return ProblemDataParameter.Value; }
     53      set { ProblemDataParameter.Value = value; }
     54    }
     55    #endregion
     56    [Storable]
     57    private Experiment experiment;
     58
     59    [StorableConstructor]
     60    private RegressionWorkbench(bool deserializing)
     61      : base(deserializing) {
     62    }
     63    private RegressionWorkbench(RegressionWorkbench original, Cloner cloner)
     64      : base(original, cloner) {
     65      experiment = cloner.Clone(original.experiment);
     66      RegisterEventHandlers();
     67    }
     68    public RegressionWorkbench()
    4169      : base() {
    4270      name = ItemName;
    4371      description = ItemDescription;
    4472
    45       executionState = ExecutionState.Stopped;
    46       runs = new RunCollection();
    47       runsCounter = 0;
    48 
    49       algorithm = null;
    50       clonedAlgorithms = new ItemCollection<IAlgorithm>();
    51       results = new ResultCollection();
    52 
    53       folds = new IntValue(2);
    54       numberOfWorkers = new IntValue(1);
    55       samplesStart = new IntValue(0);
    56       samplesEnd = new IntValue(0);
    57       storeAlgorithmInEachRun = false;
    58 
    59       RegisterEvents();
    60       if (Algorithm != null) RegisterAlgorithmEvents();
    61     }
    62 
    63     public string Filename { get; set; }
    64 
    65     #region persistence and cloning
    66     [StorableConstructor]
    67     private CrossValidation(bool deserializing)
    68       : base(deserializing) {
    69     }
     73      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The regression problem data that should be used for modeling.", new RegressionProblemData()));
     74
     75      experiment = new Experiment();
     76
     77      //var svmExperiments = CreateSvmExperiment();
     78      var rfExperiments = CreateRandomForestExperiments();
     79
     80      experiment.Optimizers.Add(new LinearRegression());
     81      experiment.Optimizers.Add(rfExperiments);
     82      //experiment.Optimizers.Add(svmExperiments);
     83
     84      RegisterEventHandlers();
     85    }
     86
    7087    [StorableHook(HookType.AfterDeserialization)]
    7188    private void AfterDeserialization() {
    72       RegisterEvents();
    73       if (Algorithm != null) RegisterAlgorithmEvents();
    74     }
    75 
    76     private CrossValidation(CrossValidation original, Cloner cloner)
    77       : base(original, cloner) {
    78       executionState = original.executionState;
    79       storeAlgorithmInEachRun = original.storeAlgorithmInEachRun;
    80       runs = cloner.Clone(original.runs);
    81       runsCounter = original.runsCounter;
    82       algorithm = cloner.Clone(original.algorithm);
    83       clonedAlgorithms = cloner.Clone(original.clonedAlgorithms);
    84       results = cloner.Clone(original.results);
    85 
    86       folds = cloner.Clone(original.folds);
    87       numberOfWorkers = cloner.Clone(original.numberOfWorkers);
    88       samplesStart = cloner.Clone(original.samplesStart);
    89       samplesEnd = cloner.Clone(original.samplesEnd);
    90       RegisterEvents();
    91       if (Algorithm != null) RegisterAlgorithmEvents();
    92     }
     89      RegisterEventHandlers();
     90    }
     91
    9392    public override IDeepCloneable Clone(Cloner cloner) {
    94       return new CrossValidation(this, cloner);
    95     }
    96 
    97     #endregion
    98 
    99     #region properties
    100     [Storable]
    101     private IAlgorithm algorithm;
    102     public IAlgorithm Algorithm {
    103       get { return algorithm; }
    104       set {
    105         if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
    106           throw new InvalidOperationException("Changing the algorithm is only allowed if the CrossValidation is stopped or prepared.");
    107         if (algorithm != value) {
    108           if (value != null && value.Problem != null && !(value.Problem is IDataAnalysisProblem))
    109             throw new ArgumentException("Only algorithms with a DataAnalysisProblem could be used for the cross validation.");
    110           if (algorithm != null) DeregisterAlgorithmEvents();
    111           algorithm = value;
    112           Parameters.Clear();
    113 
    114           if (algorithm != null) {
    115             algorithm.StoreAlgorithmInEachRun = false;
    116             RegisterAlgorithmEvents();
    117             algorithm.Prepare(true);
    118             Parameters.AddRange(algorithm.Parameters);
     93      return new RegressionWorkbench(this, cloner);
     94    }
     95
     96    private void RegisterEventHandlers() {
     97      ProblemDataParameter.ValueChanged += ProblemDataParameterValueChanged;
     98
     99      experiment.ExceptionOccurred += (sender, e) => OnExceptionOccured(e);
     100      experiment.ExecutionStateChanged += (sender, e) => OnExecutionStateChanged(e);
     101      experiment.ExecutionTimeChanged += (sender, e) => OnExecutionTimeChanged(e);
     102      experiment.Paused += (sender, e) => OnPaused(e);
     103      experiment.Prepared += (sender, e) => OnPrepared(e);
     104      experiment.Started += (sender, e) => OnStarted(e);
     105      experiment.Stopped += (sender, e) => OnStopped(e);
     106    }
     107
     108    private IOptimizer CreateRandomForestExperiments() {
     109      var exp = new Experiment();
     110      double[] rs = new double[] { 0.2, 0.3, 0.4, 0.5, 0.6, 0.65 };
     111      foreach (var r in rs) {
     112        var cv = new CrossValidation();
     113        var rf = new RandomForestRegression();
     114        rf.R = r;
     115        cv.Algorithm = rf;
     116        cv.Folds.Value = 5;
     117        exp.Optimizers.Add(cv);
     118      }
     119      return exp;
     120    }
     121
     122    private IOptimizer CreateSvmExperiment() {
     123      var exp = new Experiment();
     124      var costs = new double[] { Math.Pow(2, -5), Math.Pow(2, -3), Math.Pow(2, -1), Math.Pow(2, 1), Math.Pow(2, 3), Math.Pow(2, 5), Math.Pow(2, 7), Math.Pow(2, 9), Math.Pow(2, 11), Math.Pow(2, 13), Math.Pow(2, 15) };
     125      var gammas = new double[] { Math.Pow(2, -15), Math.Pow(2, -13), Math.Pow(2, -11), Math.Pow(2, -9), Math.Pow(2, -7), Math.Pow(2, -5), Math.Pow(2, -3), Math.Pow(2, -1), Math.Pow(2, 1), Math.Pow(2, 3) };
     126      var nus = new double[] { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 };
     127      foreach (var gamma in gammas)
     128        foreach (var cost in costs)
     129          foreach (var nu in nus) {
     130            var cv = new CrossValidation();
     131            var svr = new SupportVectorRegression();       
     132            svr.Nu.Value = nu;
     133            svr.Cost.Value = cost;
     134            svr.Gamma.Value = gamma;
     135            cv.Algorithm = svr;
     136            cv.Folds.Value = 5;
     137            exp.Optimizers.Add(cv);
    119138          }
    120           OnAlgorithmChanged();
    121           if (algorithm != null) OnProblemChanged();
    122           Prepare();
    123         }
     139      return exp;
     140    }
     141
     142    public RunCollection Runs {
     143      get { return experiment.Runs; }
     144    }
     145
     146    public void Prepare(bool clearRuns) {
     147      experiment.Prepare(clearRuns);
     148    }
     149
     150    public IEnumerable<IOptimizer> NestedOptimizers {
     151      get { return experiment.NestedOptimizers; }
     152    }
     153
     154    public ExecutionState ExecutionState {
     155      get { return experiment.ExecutionState; }
     156    }
     157
     158    public TimeSpan ExecutionTime {
     159      get { return experiment.ExecutionTime; }
     160    }
     161
     162    public void Prepare() {
     163      experiment.Prepare();
     164    }
     165
     166    public void Start() {
     167      experiment.Start();
     168    }
     169
     170    public void Pause() {
     171      experiment.Pause();
     172    }
     173
     174    public void Stop() {
     175      experiment.Stop();
     176    }
     177
     178    public event EventHandler ExecutionStateChanged;
     179    private void OnExecutionStateChanged(EventArgs e) {
     180      var handler = ExecutionStateChanged;
     181      if (handler != null) handler(this, e);
     182    }
     183
     184    public event EventHandler ExecutionTimeChanged;
     185    private void OnExecutionTimeChanged(EventArgs e) {
     186      var handler = ExecutionTimeChanged;
     187      if (handler != null) handler(this, e);
     188    }
     189
     190    public event EventHandler Prepared;
     191    private void OnPrepared(EventArgs e) {
     192      var handler = Prepared;
     193      if (handler != null) handler(this, e);
     194    }
     195
     196    public event EventHandler Started;
     197    private void OnStarted(EventArgs e) {
     198      var handler = Started;
     199      if (handler != null) handler(this, e);
     200    }
     201
     202    public event EventHandler Paused;
     203    private void OnPaused(EventArgs e) {
     204      var handler = Paused;
     205      if (handler != null) handler(this, e);
     206    }
     207
     208    public event EventHandler Stopped;
     209    private void OnStopped(EventArgs e) {
     210      var handler = Stopped;
     211      if (handler != null) handler(this, e);
     212    }
     213
     214    public event EventHandler<EventArgs<Exception>> ExceptionOccurred;
     215    private void OnExceptionOccured(EventArgs<Exception> e) {
     216      var handler = ExceptionOccurred;
     217      if (handler != null) handler(this, e);
     218    }
     219
     220    public void ProblemDataParameterValueChanged(object source, EventArgs e) {
     221      foreach (var op in NestedOptimizers.OfType<IDataAnalysisAlgorithm<IRegressionProblem>>()) {
     222        op.Problem.ProblemDataParameter.Value = ProblemData;
    124223      }
    125224    }
    126 
    127 
    128     [Storable]
    129     private IDataAnalysisProblem problem;
    130     public IDataAnalysisProblem Problem {
    131       get {
    132         if (algorithm == null)
    133           return null;
    134         return (IDataAnalysisProblem)algorithm.Problem;
    135       }
    136       set {
    137         if (ExecutionState != ExecutionState.Prepared && ExecutionState != ExecutionState.Stopped)
    138           throw new InvalidOperationException("Changing the problem is only allowed if the CrossValidation is stopped or prepared.");
    139         if (algorithm == null) throw new ArgumentNullException("Could not set a problem before an algorithm was set.");
    140         algorithm.Problem = value;
    141         problem = value;
    142       }
    143     }
    144 
    145     IProblem IAlgorithm.Problem {
    146       get { return Problem; }
    147       set {
    148         if (value != null && !ProblemType.IsInstanceOfType(value))
    149           throw new ArgumentException("Only DataAnalysisProblems could be used for the cross validation.");
    150         Problem = (IDataAnalysisProblem)value;
    151       }
    152     }
    153     public Type ProblemType {
    154       get { return typeof(IDataAnalysisProblem); }
    155     }
    156 
    157     [Storable]
    158     private ItemCollection<IAlgorithm> clonedAlgorithms;
    159 
    160     public IEnumerable<IOptimizer> NestedOptimizers {
    161       get {
    162         if (Algorithm == null) yield break;
    163         yield return Algorithm;
    164       }
    165     }
    166 
    167     [Storable]
    168     private ResultCollection results;
    169     public ResultCollection Results {
    170       get { return results; }
    171     }
    172 
    173     [Storable]
    174     private IntValue folds;
    175     public IntValue Folds {
    176       get { return folds; }
    177     }
    178     [Storable]
    179     private IntValue samplesStart;
    180     public IntValue SamplesStart {
    181       get { return samplesStart; }
    182     }
    183     [Storable]
    184     private IntValue samplesEnd;
    185     public IntValue SamplesEnd {
    186       get { return samplesEnd; }
    187     }
    188     [Storable]
    189     private IntValue numberOfWorkers;
    190     public IntValue NumberOfWorkers {
    191       get { return numberOfWorkers; }
    192     }
    193 
    194     [Storable]
    195     private bool storeAlgorithmInEachRun;
    196     public bool StoreAlgorithmInEachRun {
    197       get { return storeAlgorithmInEachRun; }
    198       set {
    199         if (storeAlgorithmInEachRun != value) {
    200           storeAlgorithmInEachRun = value;
    201           OnStoreAlgorithmInEachRunChanged();
    202         }
    203       }
    204     }
    205 
    206     [Storable]
    207     private int runsCounter;
    208     [Storable]
    209     private RunCollection runs;
    210     public RunCollection Runs {
    211       get { return runs; }
    212     }
    213 
    214     [Storable]
    215     private ExecutionState executionState;
    216     public ExecutionState ExecutionState {
    217       get { return executionState; }
    218       private set {
    219         if (executionState != value) {
    220           executionState = value;
    221           OnExecutionStateChanged();
    222           OnItemImageChanged();
    223         }
    224       }
    225     }
    226     public override Image ItemImage {
    227       get {
    228         if (ExecutionState == ExecutionState.Prepared) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePrepared;
    229         else if (ExecutionState == ExecutionState.Started) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStarted;
    230         else if (ExecutionState == ExecutionState.Paused) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutablePaused;
    231         else if (ExecutionState == ExecutionState.Stopped) return HeuristicLab.Common.Resources.VSImageLibrary.ExecutableStopped;
    232         else return HeuristicLab.Common.Resources.VSImageLibrary.Event;
    233       }
    234     }
    235 
    236     public TimeSpan ExecutionTime {
    237       get {
    238         if (ExecutionState != ExecutionState.Prepared)
    239           return TimeSpan.FromMilliseconds(clonedAlgorithms.Select(x => x.ExecutionTime.TotalMilliseconds).Sum());
    240         return TimeSpan.Zero;
    241       }
    242     }
    243     #endregion
    244 
    245     public void Prepare() {
    246       if (ExecutionState == ExecutionState.Started)
    247         throw new InvalidOperationException(string.Format("Prepare not allowed in execution state \"{0}\".", ExecutionState));
    248       results.Clear();
    249       clonedAlgorithms.Clear();
    250       if (Algorithm != null) {
    251         Algorithm.Prepare();
    252         if (Algorithm.ExecutionState == ExecutionState.Prepared) OnPrepared();
    253       }
    254     }
    255     public void Prepare(bool clearRuns) {
    256       if (clearRuns) runs.Clear();
    257       Prepare();
    258     }
    259 
    260     private bool startPending;
    261     public void Start() {
    262       if ((ExecutionState != ExecutionState.Prepared) && (ExecutionState != ExecutionState.Paused))
    263         throw new InvalidOperationException(string.Format("Start not allowed in execution state \"{0}\".", ExecutionState));
    264 
    265       if (Algorithm != null && !startPending) {
    266         startPending = true;
    267         //create cloned algorithms
    268         if (clonedAlgorithms.Count == 0) {
    269           int testSamplesCount = (SamplesEnd.Value - SamplesStart.Value) / Folds.Value;
    270 
    271           for (int i = 0; i < Folds.Value; i++) {
    272             IAlgorithm clonedAlgorithm = (IAlgorithm)algorithm.Clone();
    273             clonedAlgorithm.Name = algorithm.Name + " Fold " + i;
    274             IDataAnalysisProblem problem = clonedAlgorithm.Problem as IDataAnalysisProblem;
    275             ISymbolicDataAnalysisProblem symbolicProblem = problem as ISymbolicDataAnalysisProblem;
    276 
    277             int testStart = (i * testSamplesCount) + SamplesStart.Value;
    278             int testEnd = (i + 1) == Folds.Value ? SamplesEnd.Value : (i + 1) * testSamplesCount + SamplesStart.Value;
    279 
    280             problem.ProblemData.TestPartition.Start = testStart;
    281             problem.ProblemData.TestPartition.End = testEnd;
    282             DataAnalysisProblemData problemData = problem.ProblemData as DataAnalysisProblemData;
    283             if (problemData != null) {
    284               problemData.TrainingPartitionParameter.Hidden = false;
    285               problemData.TestPartitionParameter.Hidden = false;
    286             }
    287 
    288             if (symbolicProblem != null) {
    289               symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
    290               symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
    291             }
    292 
    293             clonedAlgorithms.Add(clonedAlgorithm);
    294           }
    295         }
    296 
    297         //start prepared or paused cloned algorithms
    298         int startedAlgorithms = 0;
    299         foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
    300           if (startedAlgorithms < NumberOfWorkers.Value) {
    301             if (clonedAlgorithm.ExecutionState == ExecutionState.Prepared ||
    302                 clonedAlgorithm.ExecutionState == ExecutionState.Paused) {
    303               clonedAlgorithm.Start();
    304               startedAlgorithms++;
    305             }
    306           }
    307         }
    308         OnStarted();
    309       }
    310     }
    311 
    312     private bool pausePending;
    313     public void Pause() {
    314       if (ExecutionState != ExecutionState.Started)
    315         throw new InvalidOperationException(string.Format("Pause not allowed in execution state \"{0}\".", ExecutionState));
    316       if (!pausePending) {
    317         pausePending = true;
    318         if (!startPending) PauseAllClonedAlgorithms();
    319       }
    320     }
    321     private void PauseAllClonedAlgorithms() {
    322       foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
    323         if (clonedAlgorithm.ExecutionState == ExecutionState.Started)
    324           clonedAlgorithm.Pause();
    325       }
    326     }
    327 
    328     private bool stopPending;
    329     public void Stop() {
    330       if ((ExecutionState != ExecutionState.Started) && (ExecutionState != ExecutionState.Paused))
    331         throw new InvalidOperationException(string.Format("Stop not allowed in execution state \"{0}\".",
    332                                                           ExecutionState));
    333       if (!stopPending) {
    334         stopPending = true;
    335         if (!startPending) StopAllClonedAlgorithms();
    336       }
    337     }
    338     private void StopAllClonedAlgorithms() {
    339       foreach (IAlgorithm clonedAlgorithm in clonedAlgorithms) {
    340         if (clonedAlgorithm.ExecutionState == ExecutionState.Started ||
    341             clonedAlgorithm.ExecutionState == ExecutionState.Paused)
    342           clonedAlgorithm.Stop();
    343       }
    344     }
    345 
    346     #region collect parameters and results
    347     public override void CollectParameterValues(IDictionary<string, IItem> values) {
    348       values.Add("Algorithm Name", new StringValue(Name));
    349       values.Add("Algorithm Type", new StringValue(GetType().GetPrettyName()));
    350       values.Add("Folds", new IntValue(Folds.Value));
    351 
    352       if (algorithm != null) {
    353         values.Add("CrossValidation Algorithm Name", new StringValue(Algorithm.Name));
    354         values.Add("CrossValidation Algorithm Type", new StringValue(Algorithm.GetType().GetPrettyName()));
    355         base.CollectParameterValues(values);
    356       }
    357       if (Problem != null) {
    358         values.Add("Problem Name", new StringValue(Problem.Name));
    359         values.Add("Problem Type", new StringValue(Problem.GetType().GetPrettyName()));
    360         Problem.CollectParameterValues(values);
    361       }
    362     }
    363 
    364     public void CollectResultValues(IDictionary<string, IItem> results) {
    365       var clonedResults = (ResultCollection)this.results.Clone();
    366       foreach (var result in clonedResults) {
    367         results.Add(result.Name, result.Value);
    368       }
    369     }
    370 
    371     private void AggregateResultValues(IDictionary<string, IItem> results) {
    372       Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();
    373       IEnumerable<IRun> runs = clonedAlgorithms.Select(alg => alg.Runs.FirstOrDefault()).Where(run => run != null);
    374       IEnumerable<KeyValuePair<string, IItem>> resultCollections = runs.Where(x => x != null).SelectMany(x => x.Results).ToList();
    375 
    376       foreach (IResult result in ExtractAndAggregateResults<IntValue>(resultCollections))
    377         results.Add(result.Name, result.Value);
    378       foreach (IResult result in ExtractAndAggregateResults<DoubleValue>(resultCollections))
    379         results.Add(result.Name, result.Value);
    380       foreach (IResult result in ExtractAndAggregateResults<PercentValue>(resultCollections))
    381         results.Add(result.Name, result.Value);
    382       foreach (IResult result in ExtractAndAggregateRegressionSolutions(resultCollections)) {
    383         results.Add(result.Name, result.Value);
    384       }
    385       foreach (IResult result in ExtractAndAggregateClassificationSolutions(resultCollections)) {
    386         results.Add(result.Name, result.Value);
    387       }
    388       results.Add("Execution Time", new TimeSpanValue(this.ExecutionTime));
    389       results.Add("CrossValidation Folds", new RunCollection(runs));
    390     }
    391 
    392     private IEnumerable<IResult> ExtractAndAggregateRegressionSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
    393       Dictionary<string, List<IRegressionSolution>> resultSolutions = new Dictionary<string, List<IRegressionSolution>>();
    394       foreach (var result in resultCollections) {
    395         var regressionSolution = result.Value as IRegressionSolution;
    396         if (regressionSolution != null) {
    397           if (resultSolutions.ContainsKey(result.Key)) {
    398             resultSolutions[result.Key].Add(regressionSolution);
    399           } else {
    400             resultSolutions.Add(result.Key, new List<IRegressionSolution>() { regressionSolution });
    401           }
    402         }
    403       }
    404       List<IResult> aggregatedResults = new List<IResult>();
    405       foreach (KeyValuePair<string, List<IRegressionSolution>> solutions in resultSolutions) {
    406         // clone manually to correctly clone references between cloned root objects
    407         Cloner cloner = new Cloner();
    408         var problemDataClone = (IRegressionProblemData)cloner.Clone(Problem.ProblemData);
    409         // set partitions of problem data clone correctly
    410         problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
    411         problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
    412         // clone models
    413         var ensembleSolution = new RegressionEnsembleSolution(
    414           solutions.Value.Select(x => cloner.Clone(x.Model)),
    415           problemDataClone,
    416           solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),
    417           solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));
    418 
    419         aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
    420       }
    421       List<IResult> flattenedResults = new List<IResult>();
    422       CollectResultsRecursively("", aggregatedResults, flattenedResults);
    423       return flattenedResults;
    424     }
    425 
    426     private IEnumerable<IResult> ExtractAndAggregateClassificationSolutions(IEnumerable<KeyValuePair<string, IItem>> resultCollections) {
    427       Dictionary<string, List<IClassificationSolution>> resultSolutions = new Dictionary<string, List<IClassificationSolution>>();
    428       foreach (var result in resultCollections) {
    429         var classificationSolution = result.Value as IClassificationSolution;
    430         if (classificationSolution != null) {
    431           if (resultSolutions.ContainsKey(result.Key)) {
    432             resultSolutions[result.Key].Add(classificationSolution);
    433           } else {
    434             resultSolutions.Add(result.Key, new List<IClassificationSolution>() { classificationSolution });
    435           }
    436         }
    437       }
    438       var aggregatedResults = new List<IResult>();
    439       foreach (KeyValuePair<string, List<IClassificationSolution>> solutions in resultSolutions) {
    440         // clone manually to correctly clone references between cloned root objects
    441         Cloner cloner = new Cloner();
    442         var problemDataClone = (IClassificationProblemData)cloner.Clone(Problem.ProblemData);
    443         // set partitions of problem data clone correctly
    444         problemDataClone.TrainingPartition.Start = SamplesStart.Value; problemDataClone.TrainingPartition.End = SamplesEnd.Value;
    445         problemDataClone.TestPartition.Start = SamplesStart.Value; problemDataClone.TestPartition.End = SamplesEnd.Value;
    446         // clone models
    447         var ensembleSolution = new ClassificationEnsembleSolution(
    448           solutions.Value.Select(x => cloner.Clone(x.Model)),
    449           problemDataClone,
    450           solutions.Value.Select(x => cloner.Clone(x.ProblemData.TrainingPartition)),
    451           solutions.Value.Select(x => cloner.Clone(x.ProblemData.TestPartition)));
    452 
    453         aggregatedResults.Add(new Result(solutions.Key + " (ensemble)", ensembleSolution));
    454       }
    455       List<IResult> flattenedResults = new List<IResult>();
    456       CollectResultsRecursively("", aggregatedResults, flattenedResults);
    457       return flattenedResults;
    458     }
    459 
    460     private void CollectResultsRecursively(string path, IEnumerable<IResult> results, IList<IResult> flattenedResults) {
    461       foreach (IResult result in results) {
    462         flattenedResults.Add(new Result(path + result.Name, result.Value));
    463         ResultCollection childCollection = result.Value as ResultCollection;
    464         if (childCollection != null) {
    465           CollectResultsRecursively(path + result.Name + ".", childCollection, flattenedResults);
    466         }
    467       }
    468     }
    469 
    470     private static IEnumerable<IResult> ExtractAndAggregateResults<T>(IEnumerable<KeyValuePair<string, IItem>> results)
    471   where T : class, IItem, new() {
    472       Dictionary<string, List<double>> resultValues = new Dictionary<string, List<double>>();
    473       foreach (var resultValue in results.Where(r => r.Value.GetType() == typeof(T))) {
    474         if (!resultValues.ContainsKey(resultValue.Key))
    475           resultValues[resultValue.Key] = new List<double>();
    476         resultValues[resultValue.Key].Add(ConvertToDouble(resultValue.Value));
    477       }
    478 
    479       DoubleValue doubleValue;
    480       if (typeof(T) == typeof(PercentValue))
    481         doubleValue = new PercentValue();
    482       else if (typeof(T) == typeof(DoubleValue))
    483         doubleValue = new DoubleValue();
    484       else if (typeof(T) == typeof(IntValue))
    485         doubleValue = new DoubleValue();
    486       else
    487         throw new NotSupportedException();
    488 
    489       List<IResult> aggregatedResults = new List<IResult>();
    490       foreach (KeyValuePair<string, List<double>> resultValue in resultValues) {
    491         doubleValue.Value = resultValue.Value.Average();
    492         aggregatedResults.Add(new Result(resultValue.Key + " (average)", (IItem)doubleValue.Clone()));
    493         doubleValue.Value = resultValue.Value.StandardDeviation();
    494         aggregatedResults.Add(new Result(resultValue.Key + " (std.dev.)", (IItem)doubleValue.Clone()));
    495       }
    496       return aggregatedResults;
    497     }
    498 
    499     private static double ConvertToDouble(IItem item) {
    500       if (item is DoubleValue) return ((DoubleValue)item).Value;
    501       else if (item is IntValue) return ((IntValue)item).Value;
    502       else throw new NotSupportedException("Could not convert any item type to double");
    503     }
    504     #endregion
    505 
    506     #region events
    507     private void RegisterEvents() {
    508       Folds.ValueChanged += new EventHandler(Folds_ValueChanged);
    509       SamplesStart.ValueChanged += new EventHandler(SamplesStart_ValueChanged);
    510       SamplesEnd.ValueChanged += new EventHandler(SamplesEnd_ValueChanged);
    511       RegisterClonedAlgorithmsEvents();
    512     }
    513     private void Folds_ValueChanged(object sender, EventArgs e) {
    514       if (ExecutionState != ExecutionState.Prepared)
    515         throw new InvalidOperationException("Can not change number of folds if the execution state is not prepared.");
    516     }
    517     private void SamplesStart_ValueChanged(object sender, EventArgs e) {
    518       if (Problem != null) Problem.ProblemData.TrainingPartition.Start = SamplesStart.Value;
    519     }
    520     private void SamplesEnd_ValueChanged(object sender, EventArgs e) {
    521       if (Problem != null) Problem.ProblemData.TrainingPartition.End = SamplesEnd.Value;
    522     }
    523 
    524     #region template algorithms events
    525     public event EventHandler AlgorithmChanged;
    526     private void OnAlgorithmChanged() {
    527       EventHandler handler = AlgorithmChanged;
    528       if (handler != null) handler(this, EventArgs.Empty);
    529       OnProblemChanged();
    530       if (Problem == null) ExecutionState = ExecutionState.Stopped;
    531     }
    532     private void RegisterAlgorithmEvents() {
    533       algorithm.ProblemChanged += new EventHandler(Algorithm_ProblemChanged);
    534       algorithm.ExecutionStateChanged += new EventHandler(Algorithm_ExecutionStateChanged);
    535     }
    536     private void DeregisterAlgorithmEvents() {
    537       algorithm.ProblemChanged -= new EventHandler(Algorithm_ProblemChanged);
    538       algorithm.ExecutionStateChanged -= new EventHandler(Algorithm_ExecutionStateChanged);
    539     }
    540     private void Algorithm_ProblemChanged(object sender, EventArgs e) {
    541       if (algorithm.Problem != null && !(algorithm.Problem is IDataAnalysisProblem)) {
    542         algorithm.Problem = problem;
    543         throw new ArgumentException("A cross validation algorithm can only contain DataAnalysisProblems.");
    544       }
    545       problem = (IDataAnalysisProblem)algorithm.Problem;
    546       OnProblemChanged();
    547     }
    548     public event EventHandler ProblemChanged;
    549     private void OnProblemChanged() {
    550       EventHandler handler = ProblemChanged;
    551       if (handler != null) handler(this, EventArgs.Empty);
    552 
    553       SamplesStart.Value = 0;
    554       if (Problem != null) {
    555         Problem.ProblemDataChanged += (object sender, EventArgs e) => OnProblemChanged();
    556         SamplesEnd.Value = Problem.ProblemData.Dataset.Rows;
    557 
    558         DataAnalysisProblemData problemData = Problem.ProblemData as DataAnalysisProblemData;
    559         if (problemData != null) {
    560           problemData.TrainingPartitionParameter.Hidden = true;
    561           problemData.TestPartitionParameter.Hidden = true;
    562         }
    563         ISymbolicDataAnalysisProblem symbolicProblem = Problem as ISymbolicDataAnalysisProblem;
    564         if (symbolicProblem != null) {
    565           symbolicProblem.FitnessCalculationPartitionParameter.Hidden = true;
    566           symbolicProblem.FitnessCalculationPartition.Start = SamplesStart.Value;
    567           symbolicProblem.FitnessCalculationPartition.End = SamplesEnd.Value;
    568           symbolicProblem.ValidationPartitionParameter.Hidden = true;
    569           symbolicProblem.ValidationPartition.Start = 0;
    570           symbolicProblem.ValidationPartition.End = 0;
    571         }
    572       } else
    573         SamplesEnd.Value = 0;
    574 
    575       SamplesStart_ValueChanged(this, EventArgs.Empty);
    576       SamplesEnd_ValueChanged(this, EventArgs.Empty);
    577     }
    578 
    579     private void Algorithm_ExecutionStateChanged(object sender, EventArgs e) {
    580       switch (Algorithm.ExecutionState) {
    581         case ExecutionState.Prepared: OnPrepared();
    582           break;
    583         case ExecutionState.Started: throw new InvalidOperationException("Algorithm template can not be started.");
    584         case ExecutionState.Paused: throw new InvalidOperationException("Algorithm template can not be paused.");
    585         case ExecutionState.Stopped: OnStopped();
    586           break;
    587       }
    588     }
    589     #endregion
    590 
    591     #region clonedAlgorithms events
    592     private void RegisterClonedAlgorithmsEvents() {
    593       clonedAlgorithms.ItemsAdded += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
    594       clonedAlgorithms.ItemsRemoved += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
    595       clonedAlgorithms.CollectionReset += new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
    596       foreach (IAlgorithm algorithm in clonedAlgorithms)
    597         RegisterClonedAlgorithmEvents(algorithm);
    598     }
    599     private void DeregisterClonedAlgorithmsEvents() {
    600       clonedAlgorithms.ItemsAdded -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsAdded);
    601       clonedAlgorithms.ItemsRemoved -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_ItemsRemoved);
    602       clonedAlgorithms.CollectionReset -= new CollectionItemsChangedEventHandler<IAlgorithm>(ClonedAlgorithms_CollectionReset);
    603       foreach (IAlgorithm algorithm in clonedAlgorithms)
    604         DeregisterClonedAlgorithmEvents(algorithm);
    605     }
    606     private void ClonedAlgorithms_ItemsAdded(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
    607       foreach (IAlgorithm algorithm in e.Items)
    608         RegisterClonedAlgorithmEvents(algorithm);
    609     }
    610     private void ClonedAlgorithms_ItemsRemoved(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
    611       foreach (IAlgorithm algorithm in e.Items)
    612         DeregisterClonedAlgorithmEvents(algorithm);
    613     }
    614     private void ClonedAlgorithms_CollectionReset(object sender, CollectionItemsChangedEventArgs<IAlgorithm> e) {
    615       foreach (IAlgorithm algorithm in e.OldItems)
    616         DeregisterClonedAlgorithmEvents(algorithm);
    617       foreach (IAlgorithm algorithm in e.Items)
    618         RegisterClonedAlgorithmEvents(algorithm);
    619     }
    620     private void RegisterClonedAlgorithmEvents(IAlgorithm algorithm) {
    621       algorithm.ExceptionOccurred += new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
    622       algorithm.ExecutionTimeChanged += new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
    623       algorithm.Started += new EventHandler(ClonedAlgorithm_Started);
    624       algorithm.Paused += new EventHandler(ClonedAlgorithm_Paused);
    625       algorithm.Stopped += new EventHandler(ClonedAlgorithm_Stopped);
    626     }
    627     private void DeregisterClonedAlgorithmEvents(IAlgorithm algorithm) {
    628       algorithm.ExceptionOccurred -= new EventHandler<EventArgs<Exception>>(ClonedAlgorithm_ExceptionOccurred);
    629       algorithm.ExecutionTimeChanged -= new EventHandler(ClonedAlgorithm_ExecutionTimeChanged);
    630       algorithm.Started -= new EventHandler(ClonedAlgorithm_Started);
    631       algorithm.Paused -= new EventHandler(ClonedAlgorithm_Paused);
    632       algorithm.Stopped -= new EventHandler(ClonedAlgorithm_Stopped);
    633     }
    634     private void ClonedAlgorithm_ExceptionOccurred(object sender, EventArgs<Exception> e) {
    635       OnExceptionOccurred(e.Value);
    636     }
    637     private void ClonedAlgorithm_ExecutionTimeChanged(object sender, EventArgs e) {
    638       OnExecutionTimeChanged();
    639     }
    640 
    641     private readonly object locker = new object();
    642     private void ClonedAlgorithm_Started(object sender, EventArgs e) {
    643       lock (locker) {
    644         IAlgorithm algorithm = sender as IAlgorithm;
    645         if (algorithm != null && !results.ContainsKey(algorithm.Name))
    646           results.Add(new Result(algorithm.Name, "Contains results for the specific fold.", algorithm.Results));
    647 
    648         if (startPending) {
    649           int startedAlgorithms = clonedAlgorithms.Count(alg => alg.ExecutionState == ExecutionState.Started);
    650           if (startedAlgorithms == NumberOfWorkers.Value ||
    651              clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Prepared))
    652             startPending = false;
    653 
    654           if (pausePending) PauseAllClonedAlgorithms();
    655           if (stopPending) StopAllClonedAlgorithms();
    656         }
    657       }
    658     }
    659 
    660     private void ClonedAlgorithm_Paused(object sender, EventArgs e) {
    661       lock (locker) {
    662         if (pausePending && clonedAlgorithms.All(alg => alg.ExecutionState != ExecutionState.Started))
    663           OnPaused();
    664       }
    665     }
    666 
    667     private void ClonedAlgorithm_Stopped(object sender, EventArgs e) {
    668       lock (locker) {
    669         if (!stopPending && ExecutionState == ExecutionState.Started) {
    670           IAlgorithm preparedAlgorithm = clonedAlgorithms.Where(alg => alg.ExecutionState == ExecutionState.Prepared ||
    671                                                                        alg.ExecutionState == ExecutionState.Paused).FirstOrDefault();
    672           if (preparedAlgorithm != null) preparedAlgorithm.Start();
    673         }
    674         if (ExecutionState != ExecutionState.Stopped) {
    675           if (clonedAlgorithms.All(alg => alg.ExecutionState == ExecutionState.Stopped))
    676             OnStopped();
    677           else if (stopPending &&
    678                    clonedAlgorithms.All(
    679                      alg => alg.ExecutionState == ExecutionState.Prepared || alg.ExecutionState == ExecutionState.Stopped))
    680             OnStopped();
    681         }
    682       }
    683     }
    684     #endregion
    685     #endregion
    686 
    687     #region event firing
    688     public event EventHandler ExecutionStateChanged;
    689     private void OnExecutionStateChanged() {
    690       EventHandler handler = ExecutionStateChanged;
    691       if (handler != null) handler(this, EventArgs.Empty);
    692     }
    693     public event EventHandler ExecutionTimeChanged;
    694     private void OnExecutionTimeChanged() {
    695       EventHandler handler = ExecutionTimeChanged;
    696       if (handler != null) handler(this, EventArgs.Empty);
    697     }
    698     public event EventHandler Prepared;
    699     private void OnPrepared() {
    700       ExecutionState = ExecutionState.Prepared;
    701       EventHandler handler = Prepared;
    702       if (handler != null) handler(this, EventArgs.Empty);
    703       OnExecutionTimeChanged();
    704     }
    705     public event EventHandler Started;
    706     private void OnStarted() {
    707       startPending = false;
    708       ExecutionState = ExecutionState.Started;
    709       EventHandler handler = Started;
    710       if (handler != null) handler(this, EventArgs.Empty);
    711     }
    712     public event EventHandler Paused;
    713     private void OnPaused() {
    714       pausePending = false;
    715       ExecutionState = ExecutionState.Paused;
    716       EventHandler handler = Paused;
    717       if (handler != null) handler(this, EventArgs.Empty);
    718     }
    719     public event EventHandler Stopped;
    720     private void OnStopped() {
    721       stopPending = false;
    722       Dictionary<string, IItem> collectedResults = new Dictionary<string, IItem>();
    723       AggregateResultValues(collectedResults);
    724       results.AddRange(collectedResults.Select(x => new Result(x.Key, x.Value)).Cast<IResult>().ToArray());
    725       runsCounter++;
    726       runs.Add(new Run(string.Format("{0} Run {1}", Name, runsCounter), this));
    727       ExecutionState = ExecutionState.Stopped;
    728       EventHandler handler = Stopped;
    729       if (handler != null) handler(this, EventArgs.Empty);
    730     }
    731     public event EventHandler<EventArgs<Exception>> ExceptionOccurred;
    732     private void OnExceptionOccurred(Exception exception) {
    733       EventHandler<EventArgs<Exception>> handler = ExceptionOccurred;
    734       if (handler != null) handler(this, new EventArgs<Exception>(exception));
    735     }
    736     public event EventHandler StoreAlgorithmInEachRunChanged;
    737     private void OnStoreAlgorithmInEachRunChanged() {
    738       EventHandler handler = StoreAlgorithmInEachRunChanged;
    739       if (handler != null) handler(this, EventArgs.Empty);
    740     }
    741     #endregion
    742225  }
    743226}
Note: See TracChangeset for help on using the changeset viewer.