Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.Instances.DataAnalysis/3.3/Classification/TimeSeries/TimeSeriesInstanceProvider.cs @ 18242

Last change on this file since 18242 was 17448, checked in by pfleck, 5 years ago

#3040 Replaced own Vector with MathNet.Numerics Vector.

File size: 8.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Collections.ObjectModel;
26using System.Diagnostics;
27using System.Globalization;
28using System.IO;
29using System.IO.Compression;
30using System.Linq;
31using HeuristicLab.Problems.DataAnalysis;
32
33using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
34
namespace HeuristicLab.Problems.Instances.DataAnalysis {
  /// <summary>
  /// Base class for instance providers that load UEA &amp; UCR time series classification problems
  /// from a zip archive of <c>.ts</c> files. Every directory in the archive that contains exactly one
  /// <c>*_TRAIN.ts</c> and one <c>*_TEST.ts</c> file is one problem instance.
  /// </summary>
  public abstract class TimeSeriesInstanceProvider : ResourceClassificationInstanceProvider {
    private const string TrainingEntrySuffix = "_TRAIN.ts";
    private const string TestEntrySuffix = "_TEST.ts";

    public override string Description {
      get { return "UEA & UCR TimeSeries Problems"; }
    }
    public override Uri WebLink {
      get { return new Uri("http://www.timeseriesclassification.com/"); }
    }
    public override string ReferencePublication {
      get { return "Anthony Bagnall, Jason Lines, William Vickers and Eamonn Keogh, The UEA & UCR Time Series Classification Repository, www.timeseriesclassification.com"; }
    }

    /// <summary>
    /// Loads the training and test files referenced by <paramref name="id"/> and combines them into a single
    /// classification problem: training rows come first, followed by the test rows, and the training/test
    /// partitions are set to exactly those ranges.
    /// </summary>
    /// <param name="id">A <see cref="TimeSeriesDataDescriptor"/>, e.g. one returned by <see cref="GetDataDescriptors"/>.</param>
    /// <returns>Problem data with one vector-valued input column per dimension (X0, X1, ...) and target "Y".</returns>
    /// <exception cref="InvalidOperationException">An entry is missing from the archive or a file is malformed.</exception>
    public override IClassificationProblemData LoadData(IDataDescriptor id) {
      var descriptor = (TimeSeriesDataDescriptor)id;
      using (var instancesZipFile = OpenZipArchive()) {
        var trainingEntry = instancesZipFile.GetEntry(descriptor.TrainingEntryName);
        var testEntry = instancesZipFile.GetEntry(descriptor.TestEntryName);

        if (trainingEntry == null || testEntry == null) {
          throw new InvalidOperationException("The training or test entry could not be found in the archive.");
        }

        using (var trainingReader = new StreamReader(trainingEntry.Open()))
        using (var testReader = new StreamReader(testEntry.Open())) {
          // Variable names and class labels come from the training file; the test header is parsed only
          // to advance the reader to the start of its data section.
          ParseMetadata(trainingReader, out var inputVariables, out string targetVariable, out var classLabels);
          ParseMetadata(testReader, out _, out _, out _);

          // One list per input dimension, holding one vector per row.
          var inputsData = new List<DoubleVector>[inputVariables.Count];
          for (int i = 0; i < inputsData.Length; i++) inputsData[i] = new List<DoubleVector>();

          // Numeric labels are used as target values directly; otherwise rows are read as strings and
          // translated to the index of their label in the declared label list.
          bool numericTarget = classLabels.All(label => !double.IsNaN(ParseNumber(label)));
          IList targetData = numericTarget ? new List<double>() : new List<string>() as IList;
          ReadData(trainingReader, inputsData, targetData, out int numTrainingRows);
          ReadData(testReader, inputsData, targetData, out int numTestRows);

          if (targetData is List<string> stringTargetData) {
            targetData = TranslateClassLabels(stringTargetData, classLabels);
          }

          var dataset = new Dataset(
            inputVariables.Concat(new[] { targetVariable }),
            inputsData.Concat(new[] { targetData })
          );
          Debug.Assert(dataset.Rows == numTrainingRows + numTestRows);
          Debug.Assert(dataset.Columns == inputVariables.Count + 1);

          var problemData = new ClassificationProblemData(dataset, inputVariables, targetVariable) {
            Name = descriptor.Name
          };
          problemData.TrainingPartition.Start = 0;
          problemData.TrainingPartition.End = numTrainingRows;
          problemData.TestPartition.Start = numTrainingRows;
          problemData.TestPartition.End = numTrainingRows + numTestRows;

          return problemData;
        }
      }
    }

    /// <summary>
    /// Maps every string class value to the zero-based position of that label in <paramref name="classLabels"/>.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="classLabels"/> contains a duplicate label.</exception>
    /// <exception cref="InvalidOperationException">A row uses a label that is not declared in the metadata.</exception>
    private static List<double> TranslateClassLabels(List<string> values, List<string> classLabels) {
      var labelTranslation = classLabels
        .Select((label, index) => new { Label = label, Index = index })
        .ToDictionary(x => x.Label, x => (double)x.Index);

      var translated = new List<double>(values.Count);
      foreach (var label in values) {
        if (!labelTranslation.TryGetValue(label, out double classValue))
          throw new InvalidOperationException($"The class label \"{label}\" is not declared in the @classLabel metadata.");
        translated.Add(classValue);
      }
      return translated;
    }

    /// <summary>
    /// Reads the header of a <c>.ts</c> file up to and including the <c>@data</c> line, so that
    /// <paramref name="reader"/> is left at the first data row. Blank lines and lines starting with '#' are skipped;
    /// keywords are matched case-insensitively and values may be separated by any run of spaces or tabs.
    /// </summary>
    /// <param name="reader">Reader positioned at the start of the file.</param>
    /// <param name="inputVariables">Generated names X0, X1, ... (zero-padded to equal width), one per input dimension.</param>
    /// <param name="targetVariable">Always "Y".</param>
    /// <param name="classLabels">The labels listed after <c>@classLabel true</c>, in declaration order.</param>
    /// <exception cref="InvalidOperationException">
    /// A non-metadata line appears before <c>@data</c>, a metadata entry has no value, or no class labels are declared.
    /// </exception>
    private static void ParseMetadata(StreamReader reader, out List<string> inputVariables, out string targetVariable, out List<string> classLabels) {
      int nrOfInputs = 0;
      List<string> labels = null;

      string line;
      while ((line = reader.ReadLine()) != null) {
        line = line.Trim();
        if (line.Length == 0 || line[0] == '#') continue; // blank line or comment
        if (line[0] != '@')
          throw new InvalidOperationException("A data section already occurred within metadata section.");

        var splits = line.Split(new[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries);
        var keyword = splits[0];
        var arguments = splits.Skip(1).ToList();

        if (IsKeyword(keyword, "@data")) break; // the data rows start on the next line

        if (IsKeyword(keyword, "@univariate")) {
          if (bool.Parse(FirstArgument(keyword, arguments))) nrOfInputs = 1;
        } else if (IsKeyword(keyword, "@dimensions")) {
          nrOfInputs = int.Parse(FirstArgument(keyword, arguments), CultureInfo.InvariantCulture);
        } else if (IsKeyword(keyword, "@classLabel")) {
          if (bool.Parse(FirstArgument(keyword, arguments))) labels = arguments.Skip(1).ToList();
        }
        // Any other metadata entry is ignored.
      }

      if (labels == null)
        throw new InvalidOperationException("The metadata does not declare any class labels (expected \"@classLabel true <labels>\").");

      // Pad all indices to the width of the largest one (e.g. X00 ... X10 for 11 inputs).
      int digits = Math.Max(nrOfInputs - 1, 0).ToString(CultureInfo.InvariantCulture).Length;
      inputVariables = Enumerable.Range(0, nrOfInputs)
        .Select(i => "X" + i.ToString("D" + digits, CultureInfo.InvariantCulture))
        .ToList();

      targetVariable = "Y";

      classLabels = labels;
    }

    /// <summary>Case-insensitive, culture-independent comparison of a header token with a keyword.</summary>
    private static bool IsKeyword(string token, string keyword) {
      return string.Equals(token, keyword, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>Returns the first value of a metadata entry or throws if the entry has none.</summary>
    private static string FirstArgument(string keyword, List<string> arguments) {
      if (arguments.Count == 0)
        throw new InvalidOperationException($"The metadata entry \"{keyword}\" has no value.");
      return arguments[0];
    }

    /// <summary>
    /// Reads all remaining data rows. Each row has the form <c>dim0:dim1:...:target</c>, where every dimension is a
    /// comma-separated series of numbers (unparseable values such as missing-value markers become NaN).
    /// Blank lines are skipped.
    /// </summary>
    /// <param name="reader">Reader positioned at the first data row.</param>
    /// <param name="inputsData">One list per input dimension; one vector is appended per row.</param>
    /// <param name="targetData">A <see cref="List{Double}"/> or <see cref="List{String}"/> receiving one target per row.</param>
    /// <param name="count">Number of rows read.</param>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="targetData"/> has an unsupported type, or a row's dimension count differs from <paramref name="inputsData"/>.
    /// </exception>
    private static void ReadData(StreamReader reader, List<DoubleVector>[] inputsData, IList targetData, out int count) {
      var numericTargetData = targetData as List<double>;
      var stringTargetData = targetData as List<string>;
      if (numericTargetData == null && stringTargetData == null)
        throw new InvalidOperationException("Target must either be numeric or a string.");

      count = 0;
      string line;
      while ((line = reader.ReadLine()) != null) {
        // Skipping blank lines (e.g. a trailing newline) keeps input and target columns the same length.
        if (string.IsNullOrWhiteSpace(line)) continue;

        var variables = line.Split(':');
        int dimensions = variables.Length - 1; // the last field is the (non-vector) target
        if (dimensions != inputsData.Length)
          throw new InvalidOperationException(
            $"Data row {count + 1} contains {dimensions} input dimension(s), but the metadata declares {inputsData.Length}.");

        for (int i = 0; i < dimensions; i++) {
          var numbers = variables[i]
            .Split(',')
            .Select(ParseNumber);
          inputsData[i].Add(DoubleVector.Build.DenseOfEnumerable(numbers));
        }

        var target = variables[dimensions].Trim();
        if (numericTargetData != null) numericTargetData.Add(ParseNumber(target));
        else stringTargetData.Add(target);

        count++;
      }
    }

    /// <summary>Parses an invariant-culture floating point number; returns NaN if the text is not a number.</summary>
    private static double ParseNumber(string number) {
      return
        double.TryParse(number, NumberStyles.Float, CultureInfo.InvariantCulture, out double parsed)
          ? parsed
          : double.NaN;
    }

    /// <summary>Returns one descriptor per instance directory found in the archive.</summary>
    public override IEnumerable<IDataDescriptor> GetDataDescriptors() {
      using (var instancesZipFile = OpenZipArchive()) {
        var instances = GroupEntriesByInstance(instancesZipFile.Entries);
        var descriptors = instances.Select(instance => CreateDescriptor(instance.Key, instance.Value));

        // Materialize before the archive is disposed.
        return descriptors.ToList();
      }
    }

    /// <summary>
    /// Opens <c>Classification/Data/{FileName}.zip</c> (a relative path, so resolved against the current directory)
    /// for reading. The caller must dispose the returned archive, which also disposes the underlying stream.
    /// </summary>
    private ZipArchive OpenZipArchive() {
      var instanceArchiveName = Path.Combine("Classification", "Data", FileName + ".zip");
      var stream = new FileStream(instanceArchiveName, FileMode.Open, FileAccess.Read, FileShare.Read);
      try {
        return new ZipArchive(stream, ZipArchiveMode.Read);
      } catch {
        // The archive never took ownership of the stream (e.g. the file is not a valid zip).
        stream.Dispose();
        throw;
      }
    }

    /// <summary>
    /// Groups the archive's training and test files by the directory containing them; the directory path is the
    /// instance name. Only file entries are inspected, so archives without explicit directory entries work too.
    /// Files at the archive root are ignored.
    /// </summary>
    private static IDictionary<string, List<ZipArchiveEntry>> GroupEntriesByInstance(ReadOnlyCollection<ZipArchiveEntry> entries) {
      return entries
        .Where(entry => entry.Name.EndsWith(TrainingEntrySuffix, StringComparison.Ordinal)
                     || entry.Name.EndsWith(TestEntrySuffix, StringComparison.Ordinal))
        .Select(entry => new { Entry = entry, Instance = Path.GetDirectoryName(entry.FullName) })
        .Where(x => !string.IsNullOrEmpty(x.Instance))
        .GroupBy(x => x.Instance)
        .ToDictionary(group => group.Key, group => group.Select(x => x.Entry).ToList());
    }

    /// <summary>
    /// Creates the descriptor for one instance.
    /// </summary>
    /// <exception cref="InvalidOperationException">The instance does not have exactly one training and one test file.</exception>
    private static TimeSeriesDataDescriptor CreateDescriptor(string name, List<ZipArchiveEntry> subEntries) {
      var trainingEntry = subEntries.Single(entry => entry.Name.EndsWith(TrainingEntrySuffix, StringComparison.Ordinal));
      var testEntry = subEntries.Single(entry => entry.Name.EndsWith(TestEntrySuffix, StringComparison.Ordinal));
      return new TimeSeriesDataDescriptor(name, trainingEntry.FullName, testEntry.FullName);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.