#region License Information
/* HeuristicLab
* Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.PluginInfrastructure;
namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// A matrix (or row vector) of doubles backed by a lazily evaluated sequence in row-major order.
/// All operations return new instances; most of them are evaluated on demand when the result is
/// enumerated. Use <see cref="Apply"/> to materialize the values.
/// </summary>
[NonDiscoverableType]
internal class Matrix : IEnumerable<double>, IDeepCloneable {
  // this type is immutable (no member modifies the instance after construction);
  // the wrapped sequence itself is not copied, so it is re-read on every enumeration
  private readonly IEnumerable<double> values;
  public readonly int Rows;
  public readonly int Columns;

  /// <summary>
  /// Cloning constructor. Materializes the values of <paramref name="original"/> into an array
  /// and registers the new instance with the <paramref name="cloner"/>.
  /// </summary>
  protected Matrix(Matrix original, Cloner cloner) {
    this.values = original.values.ToArray();
    this.Rows = original.Rows;
    this.Columns = original.Columns;
    cloner.RegisterClonedObject(original, this);
  }

  /// <summary>
  /// Creates a row vector. The sequence is enumerated once here to determine its length.
  /// </summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="vector"/> is null.</exception>
  public Matrix(IEnumerable<double> vector) {
    if (vector == null) throw new ArgumentNullException("vector");
    this.values = vector;
    Rows = 1;
    Columns = vector.Count();
  }

  /// <summary>
  /// Creates a row vector with a known <paramref name="length"/> without enumerating the sequence.
  /// </summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="vector"/> is null.</exception>
  public Matrix(IEnumerable<double> vector, int length) {
    if (vector == null) throw new ArgumentNullException("vector");
    this.values = vector;
    Rows = 1;
    Columns = length;
  }

  /// <summary>
  /// Wraps a two-dimensional array. The array is not copied; its elements are read lazily
  /// (row by row) each time the matrix is enumerated.
  /// </summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="matrix"/> is null.</exception>
  public Matrix(double[,] matrix) {
    if (matrix == null) throw new ArgumentNullException("matrix");
    this.values = GetOnlineValues(matrix);
    Rows = matrix.GetLength(0);
    Columns = matrix.GetLength(1);
  }

  /// <summary>
  /// Wraps a sequence of values given in row-major order with the stated dimensions.
  /// </summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="matrix"/> is null.</exception>
  public Matrix(IEnumerable<double> matrix, int rows, int columns) {
    if (matrix == null) throw new ArgumentNullException("matrix");
    this.values = matrix;
    Rows = rows;
    Columns = columns;
  }

  /// <summary>Creates a deep clone using a fresh <see cref="Cloner"/>.</summary>
  public object Clone() {
    return Clone(new Cloner());
  }

  /// <summary>Creates a deep clone using the given <paramref name="cloner"/>.</summary>
  public IDeepCloneable Clone(Cloner cloner) {
    return new Matrix(this, cloner);
  }

  /// <summary>Returns the (lazily evaluated) transpose with <see cref="Columns"/> rows and <see cref="Rows"/> columns.</summary>
  public Matrix Transpose() {
    return new Matrix(Transpose(values, Columns, Rows), Columns, Rows);
  }

  /// <summary>
  /// Enumerates <paramref name="values"/> in transposed order.
  /// <paramref name="rows"/> and <paramref name="columns"/> are the dimensions of the RESULT,
  /// i.e. the source has <paramref name="columns"/> rows of <paramref name="rows"/> elements each.
  /// The source is re-enumerated once per result row.
  /// </summary>
  private static IEnumerable<double> Transpose(IEnumerable<double> values, int rows, int columns) {
    // vectors don't need to be transposed
    if (rows == 1 || columns == 1) {
      foreach (var v in values) yield return v;
      yield break;
    }

    for (int skip = 0; skip < rows; skip++) {
      // a fresh enumerator per result row; 'using' guarantees the previous one is disposed
      using (var iter = values.GetEnumerator()) {
        if (!iter.MoveNext()) yield break;
        // move to element 'skip' of the first source row
        for (int i = 0; i < skip; i++) iter.MoveNext();
        yield return iter.Current;
        // jump one full source row ahead to reach the same position in each following source row
        for (int j = 0; j < columns - 1; j++) {
          for (int i = 0; i < rows; i++) iter.MoveNext();
          yield return iter.Current;
        }
      }
    }
  }

  /// <summary>Returns the (lazily evaluated) element-wise sum of this matrix and <paramref name="other"/>.</summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="other"/> is null.</exception>
  /// <exception cref="ArgumentException">Thrown when the dimensions differ.</exception>
  public Matrix Add(Matrix other) {
    CheckSameSize(other);
    return new Matrix(CombineOnline(other, (a, b) => a + b), Rows, Columns);
  }

  /// <summary>Adds the values of this matrix in place to the elements of <paramref name="matrix"/>.</summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="matrix"/> is null.</exception>
  /// <exception cref="ArgumentException">Thrown when the dimensions differ.</exception>
  public void AddTo(double[,] matrix) {
    if (matrix == null) throw new ArgumentNullException("matrix");
    if (Rows != matrix.GetLength(0) || Columns != matrix.GetLength(1)) throw new ArgumentException("unequal size", "matrix");
    using (var iter = values.GetEnumerator()) {
      for (int i = 0; i < Rows; i++)
        for (int j = 0; j < Columns; j++) {
          iter.MoveNext();
          matrix[i, j] += iter.Current;
        }
    }
  }

  /// <summary>Returns the (lazily evaluated) element-wise difference of this matrix and <paramref name="other"/>.</summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="other"/> is null.</exception>
  /// <exception cref="ArgumentException">Thrown when the dimensions differ.</exception>
  public Matrix Subtract(Matrix other) {
    CheckSameSize(other);
    return new Matrix(CombineOnline(other, (a, b) => a - b), Rows, Columns);
  }

  /// <summary>Returns the (lazily evaluated) matrix product of this matrix and <paramref name="other"/>.</summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="other"/> is null.</exception>
  /// <exception cref="ArgumentException">Thrown when <see cref="Columns"/> differs from the number of rows of <paramref name="other"/>.</exception>
  public Matrix Multiply(Matrix other) {
    if (other == null) throw new ArgumentNullException("other");
    // validated here (not inside the iterator) so that the error surfaces at the call site
    if (Columns != other.Rows) throw new ArgumentException("Number of rows and columns are not equal.");
    return new Matrix(MultiplyOnline(other), Rows, other.Columns);
  }

  /// <summary>Returns the (lazily evaluated) matrix scaled by <paramref name="value"/>.</summary>
  public Matrix Multiply(double value) {
    return new Matrix(values.Select(x => x * value), Rows, Columns);
  }

  /// <summary>Returns the Euclidean length of a row vector.</summary>
  /// <exception cref="ArgumentException">Thrown when this matrix has more than one row.</exception>
  public double VectorLength() {
    return Math.Sqrt(SquaredVectorLength());
  }

  /// <summary>Returns the sum of the squared elements of a row vector.</summary>
  /// <exception cref="ArgumentException">Thrown when this matrix has more than one row.</exception>
  public double SquaredVectorLength() {
    if (Rows != 1) throw new ArgumentException("Length only works on vectors.");
    return values.Sum(x => x * x);
  }

  /// <summary>Returns the outer product of two row vectors (this^T * other).</summary>
  /// <exception cref="ArgumentNullException">Thrown when <paramref name="other"/> is null.</exception>
  /// <exception cref="ArgumentException">Thrown when either operand is not a row vector.</exception>
  public Matrix OuterProduct(Matrix other) {
    if (other == null) throw new ArgumentNullException("other");
    if (Rows != 1 || other.Rows != 1) throw new ArgumentException("OuterProduct can only be applied to vectors.");
    return Transpose().Multiply(other);
  }

  /// <summary>Returns the (lazily evaluated) element-wise negation.</summary>
  public Matrix Negate() {
    return new Matrix(values.Select(x => -x), Rows, Columns);
  }

  /// <summary>Evaluates all pending operations and returns a matrix backed by an array of the results.</summary>
  public Matrix Apply() {
    return new Matrix(values.ToArray(), Rows, Columns);
  }

  public IEnumerator<double> GetEnumerator() { return values.GetEnumerator(); }
  IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); }

  // Shared eager argument check for the element-wise binary operations.
  private void CheckSameSize(Matrix other) {
    if (other == null) throw new ArgumentNullException("other");
    if (Rows != other.Rows || Columns != other.Columns) throw new ArgumentException("Number of rows and columns are not equal.");
  }

  // Lazily combines the elements of this matrix and 'other' pairwise with 'op',
  // yielding Rows * Columns results (nothing if either sequence is empty).
  private IEnumerable<double> CombineOnline(Matrix other, Func<double, double, double> op) {
    using (var meIter = values.GetEnumerator())
    using (var otherIter = other.GetEnumerator()) {
      if (!meIter.MoveNext() || !otherIter.MoveNext()) yield break;
      int count = Rows * Columns;
      for (int i = 0; i < count; i++) {
        yield return op(meIter.Current, otherIter.Current);
        meIter.MoveNext();
        otherIter.MoveNext();
      }
    }
  }

  // Lazily computes the matrix product row by row. 'other' is enumerated in transposed
  // (column-by-column) order once for every row of this matrix.
  private IEnumerable<double> MultiplyOnline(Matrix other) {
    var otherByColumn = other.Transpose();
    using (var meIter = values.GetEnumerator()) {
      if (!meIter.MoveNext()) yield break;
      // buffer for the current row of this matrix; completely overwritten for every row
      var row = new double[Columns];
      for (int r = 0; r < Rows; r++) {
        for (int x = 0; x < Columns; x++) {
          row[x] = meIter.Current;
          meIter.MoveNext();
        }
        using (var otherIter = otherByColumn.GetEnumerator()) {
          if (!otherIter.MoveNext()) yield break;
          for (int c = 0; c < other.Columns; c++) {
            var sum = 0.0;
            for (int y = 0; y < other.Rows; y++) {
              sum += row[y] * otherIter.Current;
              otherIter.MoveNext();
            }
            yield return sum;
          }
        }
      }
    }
  }

  // Lazily enumerates a two-dimensional array in row-major order.
  private static IEnumerable<double> GetOnlineValues(double[,] matrix) {
    for (int i = 0; i < matrix.GetLength(0); i++)
      for (int j = 0; j < matrix.GetLength(1); j++) {
        yield return matrix[i, j];
      }
  }
}
}