【Java】行列計算用のクラス
行列計算
信号処理や機械学習のアルゴリズムには行列計算が多く使われています。
例えば多次元正規分布では分散共分散行列の逆行列計算を行ったり、ニューラルネットワークでは順伝搬計算を行う際にユニットの出力とユニット結線荷重の内積計算を行ったりします。
機械学習のアルゴリズムを実装する際には、毎回1から行列計算処理を実装していたのですが、そろそろ専用のクラスを作ろうと思いました。
public class Matrix2D { private double[][] matrix; public Matrix2D(double[] vector){ this.matrix = new double[vector.length][1]; for(int i=0; i<vector.length; i++){ this.matrix[i][0] = vector[i]; } } public Matrix2D(double[][] vector){ this.matrix = new double[vector.length][vector[0].length]; for(int r=0; r<vector.length; r++){ for(int c=0; c<vector[r].length; c++){ this.matrix[r][c] = vector[r][c]; } } } //行列の取得 public double[][] getArrays(){ return this.matrix; } //行数を取得 public int getRow(){ return this.matrix.length; } //列数を取得 public int getCol(){ return this.matrix[0].length; } //転置行列を取得 public Matrix2D T(){ double[][] t = new double[this.getCol()][this.getRow()]; for(int r=0; r<t.length; r++){ for(int c=0; c<t[r].length; c++){ t[r][c] = this.matrix[c][r]; } } return new Matrix2D(t); } private void changeValue(int row, int col, double a){ this.matrix[row][col] = a; } private double getValue(int row, int col){ return this.matrix[row][col]; } //行列の内積 public static Matrix2D dot(Matrix2D a, Matrix2D b){ double[][] d = new double[a.getRow()][b.getCol()]; for(int r=0; r<a.getRow(); r++){ for(int c=0; c<b.getCol(); c++){ double sum = 0; for(int k=0; k<b.getRow(); k++){ sum += a.getValue(r, k) * b.getValue(k, c); } d[r][c] = sum; } } return new Matrix2D(d); } //行列の加算 public static Matrix2D add(Matrix2D a, Matrix2D b){ double[][] d = new double[a.getRow()][a.getCol()]; for(int c=0; c<a.getRow(); c++){ for(int r=0; r<a.getCol(); r++){ d[r][c] = a.getValue(r, c) + b.getValue(r, c); } } return new Matrix2D(d); } //行列の減算 public static Matrix2D sub(Matrix2D a, Matrix2D b){ double[][] d = new double[a.getRow()][a.getCol()]; for(int c=0; c<a.getRow(); c++){ for(int r=0; r<a.getCol(); r++){ d[r][c] = a.getValue(r, c) - b.getValue(r, c); } } return new Matrix2D(d); } //行列のルート計算 public static Matrix2D sqrt(Matrix2D a){ double[][] d = new double[a.getRow()][a.getCol()]; for(int r=0; r<a.getRow(); r++){ for(int c=0; c<a.getCol(); c++){ d[r][c] = Math.sqrt(a.getValue(r, c)); } } return new Matrix2D(d); } //行列のルート計算 public static Matrix2D sqrt(Matrix2D a, double b){ double[][] d = new double[a.getRow()][a.getCol()]; for(int r=0; r<a.getRow(); r++){ for(int c=0; c<a.getCol(); c++){ d[r][c] = Math.pow(a.getValue(r, c), b); } } return new Matrix2D(d); } //行列の定数倍 public static Matrix2D prod(double a, Matrix2D b){ double[][] d = new double[b.getRow()][b.getCol()]; for(int r=0; r<b.getRow(); r++){ for(int c=0; c<b.getCol(); c++){ d[r][c] = b.getValue(r, c) * a; } } return new Matrix2D(d); } //行列の積 public static Matrix2D mult(Matrix2D a, Matrix2D b){ if(a.getCol() == b.getRow()){ double[][] d = new double[a.getRow()][b.getCol()]; for(int r=0; r<a.getRow(); r++){ for(int c=0; c<b.getCol(); c++){ double sum = 0; for(int i=0; i<b.getRow(); i++){ sum += a.getValue(r, i) * b.getValue(i, c); } d[r][c] = sum; } } return new Matrix2D(d); } else{ return null; } } //行列の横方向への結合 public static Matrix2D catHorizon(Matrix2D a, Matrix2D b){ if(a.getRow() == b.getRow()){ double[][] d = new double[a.getRow()][a.getCol() + b.getCol()]; for(int r=0; r<a.getRow(); r++){ for(int c=0; c<a.getCol(); c++){ d[r][c] = a.getValue(r, c); } for(int c=0; c<b.getCol(); c++){ d[r][a.getCol()+c] = b.getValue(r, c); } } return new Matrix2D(d); } return null; } //行列の縦方向への結合 public static Matrix2D catVertical(Matrix2D a, Matrix2D b){ if(a.getCol() == b.getCol()){ double[][] d = new double[a.getRow() + b.getRow()][a.getCol()]; for(int c=0; c<a.getCol(); c++){ for(int r=0; r<a.getRow(); r++){ d[r][c] = a.getValue(r, c); } for(int r=0; r<b.getRow(); r++){ d[a.getRow()+r][c] = b.getValue(r, c); } } return new Matrix2D(d); } else return null; } private static Matrix2D pivodRow(int i, int j, Matrix2D a){ double[][] b = new double[a.getRow()][a.getCol()]; for(int r=0; r<a.getRow(); r++){ for(int c=0; c<a.getCol(); c++){ if(r == i)b[r][c] = a.getValue(j, c); else if(r == j)b[r][c] = a.getValue(i, c); else b[r][c] = a.getValue(r, c); } } return new Matrix2D(b); } //逆行列の取得 public static Matrix2D inv(Matrix2D a){ Matrix2D d = Matrix2D.catHorizon(a, Matrix2D.IdentityMatrix(a.getRow())); for(int r=0; r<a.getRow(); r++){ if(a.getValue(r, r) >= 0)continue; for(int i=r+1; i<a.getRow(); i++){ if(a.getValue(i, r) >= 0){ d = Matrix2D.pivodRow(r, i, d); break; } } } for(int r=0; r<a.getCol(); r++){ double A = 1.0 / d.getValue(r, r); for(int c=0; c<d.getCol(); c++){ d.changeValue(r, c, A*d.getValue(r, c)); } for(int i=r+1; i<a.getRow(); i++){ double B = d.getValue(i, r); for(int c=0; c<d.getCol(); c++){ d.changeValue(i, c, d.getValue(i, c)-B*d.getValue(r, c)); } } } for(int r=0; r<a.getRow(); r++){ for(int i=r+1; i<a.getCol(); i++){ double B = d.getValue(r, i); for(int j=i; j<d.getCol(); j++){ d.changeValue(r, j, d.getValue(r, j) - B*d.getValue(i, j)); } } } double[][] e = new double[a.getRow()][a.getCol()]; for(int r=0; r<a.getRow(); r++){ for(int c=0; c<a.getCol(); c++){ e[r][c] = d.getValue(r, c+a.getCol()); } } return new Matrix2D(e); } //単位行列の取得 public static Matrix2D IdentityMatrix(int N){ double[][] a = new double[N][N]; for(int r=0; r<N; r++){ for(int c=0; c<N; c++){ if(r == c)a[r][c] = 1.0; } } return new Matrix2D(a); } @Override public String toString(){ StringBuilder sb = new StringBuilder(); for(int r=0; r<getRow(); r++){ sb.append("|"); for(int c=0; c<getCol(); c++){ if(matrix[r][c] < 0)sb.append(String.format("%.5f ", matrix[r][c])); else sb.append(String.format(" %.5f ", matrix[r][c])); } sb.append("|\n"); } return sb.toString(); }
このクラスは次のように使います
//1次元の配列から行列の作成 double[] _a = {1, 2, 3, 4}; Matrix2D a = new Matrix2D(_a); System.out.println(a);
| 1.00000 | | 2.00000 | | 3.00000 | | 4.00000 |
//2次元の配列から行列を作成 double[][] _b = { {1, 2, 3}, {4, 5, 6}, {7, 8, 9} }; Matrix2D b = new Matrix2D(_b); System.out.println(b);
| 1.00000 2.00000 3.00000 | | 4.00000 5.00000 6.00000 | | 7.00000 8.00000 9.00000 |
//行と列数の取得 System.out.println(a.getRow()+"\t"+b.getCol());
4 3
//転置行列
System.out.println(a.T());
System.out.println(b.T());
| 1.00000 2.00000 3.00000 4.00000 | | 1.00000 4.00000 7.00000 | | 2.00000 5.00000 8.00000 | | 3.00000 6.00000 9.00000 |
//行列の和と差
System.out.println(Matrix2D.add(b, b));
System.out.println(Matrix2D.sub(b, b));
| 2.00000 4.00000 6.00000 | | 8.00000 10.00000 12.00000 | | 14.00000 16.00000 18.00000 | | 0.00000 0.00000 0.00000 | | 0.00000 0.00000 0.00000 | | 0.00000 0.00000 0.00000 |
//行列の低数倍 System.out.println(Matrix2D.prod(100, b));
| 100.00000 200.00000 300.00000 | | 400.00000 500.00000 600.00000 | | 700.00000 800.00000 900.00000 |
//行列の内積 double[][] _c = { {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, }; double[][] _d = { {1, 2}, {3, 4}, {5, 6}, {7, 8}, }; System.out.println(Matrix2D.dot(new Matrix2D(_c), new Matrix2D(_d)));
| 50.00000 60.00000 | | 114.00000 140.00000 | | 178.00000 220.00000 |
//行列の結合(横方向) System.out.println(Matrix2D.catHorizon(b , b)); //行列の結合(縦方向) System.out.println(Matrix2D.catVertical(b, b));
| 1.00000 2.00000 3.00000 1.00000 2.00000 3.00000 | | 4.00000 5.00000 6.00000 4.00000 5.00000 6.00000 | | 7.00000 8.00000 9.00000 7.00000 8.00000 9.00000 | | 1.00000 2.00000 3.00000 | | 4.00000 5.00000 6.00000 | | 7.00000 8.00000 9.00000 | | 1.00000 2.00000 3.00000 | | 4.00000 5.00000 6.00000 | | 7.00000 8.00000 9.00000 |
//N×Nの単位行列 System.out.println(Matrix2D.IdentityMatrix(3));
| 1.00000 0.00000 0.00000 | | 0.00000 1.00000 0.00000 | | 0.00000 0.00000 1.00000 |
//逆行列の取得 double[][] _e = { {2, 3, 4}, {5, 6, 7}, {8, 9, 0}, }; System.out.println(Matrix2D.inv(new Matrix2D(_e)));
|-2.10000 1.20000 -0.10000 | | 1.86667 -1.06667 0.20000 | |-0.10000 0.20000 -0.10000 |