元理系院生の新入社員がPythonとJavaで色々頑張るブログ

プログラミングや機械学習について調べた事を書いていきます

【Java】行列計算用のクラス

行列計算

信号処理や機械学習アルゴリズムには行列計算が多く使われています。
例えば多次元正規分布では分散共分散行列の逆行列計算を行ったり、ニューラルネットワークでは順伝搬計算を行う際にユニットの出力とユニット結線荷重の内積計算を行ったりします。
機械学習アルゴリズムを実装する際には、毎回1から行列計算処理を実装していたのですが、そろそろ専用のクラスを作ろうと思いました。


public class Matrix2D {
	private double[][] matrix;

	public Matrix2D(double[] vector){
		this.matrix = new double[vector.length][1];
		for(int i=0; i<vector.length; i++){
			this.matrix[i][0] = vector[i];
		}
	}

	public Matrix2D(double[][] vector){
		this.matrix = new double[vector.length][vector[0].length];
		for(int r=0; r<vector.length; r++){
			for(int c=0; c<vector[r].length; c++){
				this.matrix[r][c] = vector[r][c];
			}
		}
	}

	//行列の取得
	public double[][] getArrays(){
		return this.matrix;
	}

	//行数を取得
	public int getRow(){
		return this.matrix.length;
	}

	//列数を取得
	public int getCol(){
		return this.matrix[0].length;
	}

	//転置行列を取得
	public Matrix2D T(){
		double[][] t = new double[this.getCol()][this.getRow()];
		for(int r=0; r<t.length; r++){
			for(int c=0; c<t[r].length; c++){
				t[r][c] = this.matrix[c][r];
			}
		}
		return new Matrix2D(t);
	}

	private void changeValue(int row, int col, double a){
		this.matrix[row][col] = a;
	}

	private double getValue(int row, int col){
		return this.matrix[row][col];
	}

	//行列の内積
	public static Matrix2D dot(Matrix2D a, Matrix2D b){
		double[][] d = new double[a.getRow()][b.getCol()];
		for(int r=0; r<a.getRow(); r++){
			for(int c=0; c<b.getCol(); c++){
				double sum = 0;
				for(int k=0; k<b.getRow(); k++){
					sum += a.getValue(r, k) * b.getValue(k, c);
				}
				d[r][c] = sum;
			}
		}
		return new Matrix2D(d);
	}

	//行列の加算
	public static Matrix2D add(Matrix2D a, Matrix2D b){
		double[][] d = new double[a.getRow()][a.getCol()];
		for(int c=0; c<a.getRow(); c++){
			for(int r=0; r<a.getCol(); r++){
				d[r][c] = a.getValue(r, c) + b.getValue(r, c);
			}
		}
		return new Matrix2D(d);
	}

	//行列の減算
	public static Matrix2D sub(Matrix2D a, Matrix2D b){
		double[][] d = new double[a.getRow()][a.getCol()];
		for(int c=0; c<a.getRow(); c++){
			for(int r=0; r<a.getCol(); r++){
				d[r][c] = a.getValue(r, c) - b.getValue(r, c);
			}
		}
		return new Matrix2D(d);
	}
	
	//行列のルート計算
	public static Matrix2D sqrt(Matrix2D a){
		double[][] d = new double[a.getRow()][a.getCol()];
		for(int r=0; r<a.getRow(); r++){
			for(int c=0; c<a.getCol(); c++){
				d[r][c] = Math.sqrt(a.getValue(r, c));
			}
		}
		return new Matrix2D(d);
	}

	//行列のルート計算
	public static Matrix2D sqrt(Matrix2D a, double b){
		double[][] d = new double[a.getRow()][a.getCol()];
		for(int r=0; r<a.getRow(); r++){
			for(int c=0; c<a.getCol(); c++){
				d[r][c] = Math.pow(a.getValue(r, c), b);
			}
		}
		return new Matrix2D(d);
	}

	//行列の定数倍
	public static Matrix2D prod(double a, Matrix2D b){
		double[][] d = new double[b.getRow()][b.getCol()];
		for(int r=0; r<b.getRow(); r++){
			for(int c=0; c<b.getCol(); c++){
				d[r][c] = b.getValue(r, c) * a;
			}
		}
		return new Matrix2D(d);
	}

	//行列の積
	public static Matrix2D mult(Matrix2D a, Matrix2D b){
		if(a.getCol() == b.getRow()){
			double[][] d = new double[a.getRow()][b.getCol()];
			for(int r=0; r<a.getRow(); r++){
				for(int c=0; c<b.getCol(); c++){
					double sum = 0;
					for(int i=0; i<b.getRow(); i++){
						sum += a.getValue(r, i) * b.getValue(i, c);
					}
					d[r][c] = sum;
				}
			}
			return new Matrix2D(d);
		}
		else{
			return null;
		}
	}

	//行列の横方向への結合
	public static Matrix2D catHorizon(Matrix2D a, Matrix2D b){
		if(a.getRow() == b.getRow()){
			double[][] d = new double[a.getRow()][a.getCol() + b.getCol()];
			for(int r=0; r<a.getRow(); r++){
				for(int c=0; c<a.getCol(); c++){
					d[r][c] = a.getValue(r, c);
				}
				for(int c=0; c<b.getCol(); c++){
					d[r][a.getCol()+c] = b.getValue(r, c);
				}
			}
			return new Matrix2D(d);
		}
		return null;
	}

	//行列の縦方向への結合
	public static Matrix2D catVertical(Matrix2D a, Matrix2D b){
		if(a.getCol() == b.getCol()){
			double[][] d = new double[a.getRow() + b.getRow()][a.getCol()];
			for(int c=0; c<a.getCol(); c++){
				for(int r=0; r<a.getRow(); r++){
					d[r][c] = a.getValue(r, c);
				}
				for(int r=0; r<b.getRow(); r++){
					d[a.getRow()+r][c] = b.getValue(r, c);
				}
			}
			return new Matrix2D(d);
		}
		else return null;
	}

	private static Matrix2D pivodRow(int i, int j, Matrix2D a){
		double[][] b = new double[a.getRow()][a.getCol()];
		for(int r=0; r<a.getRow(); r++){
			for(int c=0; c<a.getCol(); c++){
				if(r == i)b[r][c] = a.getValue(j, c);
				else if(r == j)b[r][c] = a.getValue(i, c);
				else b[r][c] = a.getValue(r, c);
			}
		}
		return new Matrix2D(b);
	}

	//逆行列の取得
	public static Matrix2D inv(Matrix2D a){
		Matrix2D d = Matrix2D.catHorizon(a, Matrix2D.IdentityMatrix(a.getRow()));
		for(int r=0; r<a.getRow(); r++){
			if(a.getValue(r, r) >= 0)continue;
			for(int i=r+1; i<a.getRow(); i++){
				if(a.getValue(i, r) >= 0){
					d = Matrix2D.pivodRow(r, i, d);
					break;
				}
			}
		}

		for(int r=0; r<a.getCol(); r++){
			double A = 1.0 / d.getValue(r, r);
			for(int c=0; c<d.getCol(); c++){
				d.changeValue(r, c, A*d.getValue(r, c));
			}

			for(int i=r+1; i<a.getRow(); i++){
				double B = d.getValue(i, r);
				for(int c=0; c<d.getCol(); c++){
					d.changeValue(i, c, d.getValue(i, c)-B*d.getValue(r, c));
				}
			}
		}

		for(int r=0; r<a.getRow(); r++){
			for(int i=r+1; i<a.getCol(); i++){
				double B = d.getValue(r, i);
				for(int j=i; j<d.getCol(); j++){
					d.changeValue(r, j, d.getValue(r, j) - B*d.getValue(i, j));
				}
			}
		}
		double[][] e = new double[a.getRow()][a.getCol()];
		for(int r=0; r<a.getRow(); r++){
			for(int c=0; c<a.getCol(); c++){
				e[r][c] = d.getValue(r, c+a.getCol());
			}
		}
		return new Matrix2D(e);
	}
	
	//単位行列の取得
	public static Matrix2D IdentityMatrix(int N){
		double[][] a = new double[N][N];
		for(int r=0; r<N; r++){
			for(int c=0; c<N; c++){
				if(r == c)a[r][c] = 1.0;
			}
		}
		return new Matrix2D(a);
	}

	@Override
	public String toString(){
		StringBuilder sb = new StringBuilder();
		for(int r=0; r<getRow(); r++){
			sb.append("|");
			for(int c=0; c<getCol(); c++){
				if(matrix[r][c] < 0)sb.append(String.format("%.5f ", matrix[r][c]));
				else sb.append(String.format(" %.5f ", matrix[r][c]));
			}
			sb.append("|\n");
		}
		return sb.toString();
	}

このクラスは次のように使います

//1次元の配列から行列の作成
double[] _a = {1, 2, 3, 4};
Matrix2D a = new Matrix2D(_a);
System.out.println(a);
| 1.00000 |
| 2.00000 |
| 3.00000 |
| 4.00000 |
//2次元の配列から行列を作成
double[][] _b = {
	{1, 2, 3},
	{4, 5, 6},
	{7, 8, 9}
};
Matrix2D b = new Matrix2D(_b);
System.out.println(b);
| 1.00000  2.00000  3.00000 |
| 4.00000  5.00000  6.00000 |
| 7.00000  8.00000  9.00000 |
//行と列数の取得
System.out.println(a.getRow()+"\t"+b.getCol());
4	3
//転置行列
System.out.println(a.T());
System.out.println(b.T());
| 1.00000  2.00000  3.00000  4.00000 |

| 1.00000  4.00000  7.00000 |
| 2.00000  5.00000  8.00000 |
| 3.00000  6.00000  9.00000 |
//行列の和と差
System.out.println(Matrix2D.add(b, b));
System.out.println(Matrix2D.sub(b, b));
| 2.00000  4.00000  6.00000 |
| 8.00000  10.00000  12.00000 |
| 14.00000  16.00000  18.00000 |

| 0.00000  0.00000  0.00000 |
| 0.00000  0.00000  0.00000 |
| 0.00000  0.00000  0.00000 |
//行列の低数倍
System.out.println(Matrix2D.prod(100, b));
| 100.00000  200.00000  300.00000 |
| 400.00000  500.00000  600.00000 |
| 700.00000  800.00000  900.00000 |
//行列の内積
double[][] _c = {
	{1, 2, 3, 4},
	{5, 6, 7, 8},
	{9, 10, 11, 12},
};
double[][] _d = {
	{1, 2},
	{3, 4},
	{5, 6},
	{7, 8},
};
System.out.println(Matrix2D.dot(new Matrix2D(_c), new Matrix2D(_d)));
| 50.00000  60.00000 |
| 114.00000  140.00000 |
| 178.00000  220.00000 |
//行列の結合(横方向)
System.out.println(Matrix2D.catHorizon(b , b));
//行列の結合(縦方向)
System.out.println(Matrix2D.catVertical(b, b));
| 1.00000  2.00000  3.00000  1.00000  2.00000  3.00000 |
| 4.00000  5.00000  6.00000  4.00000  5.00000  6.00000 |
| 7.00000  8.00000  9.00000  7.00000  8.00000  9.00000 |

| 1.00000  2.00000  3.00000 |
| 4.00000  5.00000  6.00000 |
| 7.00000  8.00000  9.00000 |
| 1.00000  2.00000  3.00000 |
| 4.00000  5.00000  6.00000 |
| 7.00000  8.00000  9.00000 |
//N×Nの単位行列
System.out.println(Matrix2D.IdentityMatrix(3));
| 1.00000  0.00000  0.00000 |
| 0.00000  1.00000  0.00000 |
| 0.00000  0.00000  1.00000 |
//逆行列の取得
double[][] _e = {
	{2, 3, 4},
	{5, 6, 7},
	{8, 9, 0},
};
System.out.println(Matrix2D.inv(new Matrix2D(_e)));
|-2.10000  1.20000 -0.10000 |
| 1.86667 -1.06667  0.20000 |
|-0.10000  0.20000 -0.10000 |