Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Operator Architecture Refactoring/LibSVM/GaussianTransform.cs @ 2443

Last change on this file since 2443 was 1819, checked in by mkommend, 16 years ago

created new project for LibSVM source files (ticket #619)

File size: 7.5 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Collections.Generic;
22using System.IO;
23
24namespace SVM
25{
26    /// <remarks>
27    /// A transform which learns the mean and variance of a sample set and uses these to transform new data
28    /// so that it has zero mean and unit variance.
29    /// </remarks>
30    public class GaussianTransform : IRangeTransform
31    {
32        private List<Node[]> _samples;
33        private int _maxIndex;
34
35        private double[] _means;
36        private double[] _stddevs;
37
38        /// <summary>
39        /// Constructor.
40        /// </summary>
41        /// <param name="maxIndex">The maximum index of the vectors to be transformed</param>
42        public GaussianTransform(int maxIndex)
43        {
44            _samples = new List<Node[]>();
45        }
46        private GaussianTransform(double[] means, double[] stddevs, int maxIndex)
47        {
48            _means = means;
49            _stddevs = stddevs;
50            _maxIndex = maxIndex;
51        }
52
53        /// <summary>
54        /// Adds a sample to the data.  No computation is performed.  The maximum index of the
55        /// sample must be less than MaxIndex.
56        /// </summary>
57        /// <param name="sample">The sample to add</param>
58        public void Add(Node[] sample)
59        {
60            _samples.Add(sample);
61        }
62
63        /// <summary>
64        /// Computes the statistics for the samples which have been obtained so far.
65        /// </summary>
66        public void ComputeStatistics()
67        {
68            int[] counts = new int[_maxIndex];
69            _means = new double[_maxIndex];
70            foreach(Node[] sample in _samples)
71            {
72                for (int i = 0; i < sample.Length; i++)
73                {
74                    _means[sample[i].Index] += sample[i].Value;
75                    counts[sample[i].Index]++;
76                }
77            }
78            for (int i = 0; i < _maxIndex; i++)
79            {
80                if (counts[i] == 0)
81                    counts[i] = 2;
82                _means[i] /= counts[i];
83            }
84
85            _stddevs = new double[_maxIndex];
86            foreach(Node[] sample in _samples)
87            {
88                for (int i = 0; i < sample.Length; i++)
89                {
90                    double diff = sample[i].Value - _means[sample[i].Index];
91                    _stddevs[sample[i].Index] += diff * diff;
92                }
93            }
94            for (int i = 0; i < _maxIndex; i++)
95            {
96                if (_stddevs[i] == 0)
97                    continue;
98                _stddevs[i] /= (counts[i]-1);
99                _stddevs[i] = Math.Sqrt(_stddevs[i]);
100            }
101        }
102
103        /// <summary>
104        /// Saves the transform to the disk.  The samples are not stored, only the
105        /// statistics.
106        /// </summary>
107        /// <param name="stream">The destination stream</param>
108        /// <param name="transform">The transform</param>
109        public static void Write(Stream stream, GaussianTransform transform)
110        {
111            StreamWriter output = new StreamWriter(stream);
112            output.WriteLine(transform._maxIndex);
113            for (int i = 0; i < transform._maxIndex; i++)
114                output.WriteLine("{0} {1}", transform._means[i], transform._stddevs[i]);
115            output.Flush();
116        }
117
118        /// <summary>
119        /// Reads a GaussianTransform from the provided stream.
120        /// </summary>
121        /// <param name="stream">The source stream</param>
122        /// <returns>The transform</returns>
123        public static GaussianTransform Read(Stream stream)
124        {
125            StreamReader input = new StreamReader(stream);
126            int length = int.Parse(input.ReadLine());
127            double[] means = new double[length];
128            double[] stddevs = new double[length];
129            for (int i = 0; i < length; i++)
130            {
131                string[] parts = input.ReadLine().Split();
132                means[i] = double.Parse(parts[0]);
133                stddevs[i] = double.Parse(parts[1]);
134            }
135            return new GaussianTransform(means, stddevs, length);
136        }
137
138        /// <summary>
139        /// Saves the transform to the disk.  The samples are not stored, only the
140        /// statistics.
141        /// </summary>
142        /// <param name="filename">The destination filename</param>
143        /// <param name="transform">The transform</param>
144        public static void Write(string filename, GaussianTransform transform)
145        {
146            FileStream output = File.Open(filename, FileMode.Create);
147            try
148            {
149                Write(output, transform);
150            }
151            finally
152            {
153                output.Close();
154            }
155        }
156
157        /// <summary>
158        /// Reads a GaussianTransform from the provided stream.
159        /// </summary>
160        /// <param name="filename">The source filename</param>
161        /// <returns>The transform</returns>
162        public static GaussianTransform Read(string filename)
163        {
164            FileStream input = File.Open(filename, FileMode.Open);
165            try
166            {
167                return Read(input);
168            }
169            finally
170            {
171                input.Close();
172            }
173        }
174
175        #region IRangeTransform Members
176
177        /// <summary>
178        /// Transform the input value using the transform stored for the provided index.
179        /// <see cref="ComputeStatistics"/> must be called first, or the transform must
180        /// have been read from the disk.
181        /// </summary>
182        /// <param name="input">Input value</param>
183        /// <param name="index">Index of the transform to use</param>
184        /// <returns>The transformed value</returns>
185        public double Transform(double input, int index)
186        {
187            if (_stddevs[index] == 0)
188                return 0;
189            double diff = input - _means[index];
190            diff /= _stddevs[index];
191            return diff;
192        }
193        /// <summary>
194        /// Transforms the input array.  <see cref="ComputeStatistics"/> must be called
195        /// first, or the transform must have been read from the disk.
196        /// </summary>
197        /// <param name="input">The array to transform</param>
198        /// <returns>The transformed array</returns>
199        public Node[] Transform(Node[] input)
200        {
201            Node[] output = new Node[input.Length];
202            for (int i = 0; i < output.Length; i++)
203            {
204                int index = input[i].Index;
205                double value = input[i].Value;
206                output[i] = new Node(index, Transform(value, index));
207            }
208            return output;
209        }
210
211        #endregion
212    }
213}
Note: See TracBrowser for help on using the repository browser.