- 参考:《机器学习算法框架实战:Java和Python实现》
- python实现主要是调用 NumPy 库做的;java实现基本没有调库
1. 说明
1.1 程序组织
1.2 数据结构
1.2.1 python实现
- 直接使用 NumPy 提供的
ndarray 数组作为矩阵即可import numpy as np
def test_mat_basic():
mat1 = np.array([[1, 2, 3], [4, 5, 6]])
print(mat1)
mat2 = np.zeros((3, 2))
print(mat2)
mat3 = np.ones((3, 2))
print(mat3)
mat4 = np.random.rand(3, 2)
print(mat4)
if __name__ == "__main__":
test_mat_basic()
- 更多构造方法如下
方法名 | 描述 |
---|
array | 将输入数据(列表、元组、数组及其他序列)转换为ndarray ,如不显式指明数据类型,将自动推断;默认复制所有输入数据 | asarray | 将输入转换为ndarray ,但若输入已经是ndarray 则不再复制 | arange | python内置range 函数的ndarray 版本,返回一个ndarray | ones | 根据给定形状和数据类型生成全1 数组 | ones_like | 根据给定的数组生成一个形状一样的全1 数组 | zeros | 根据给定形状和数据类型生成全0 数组 | zeros_like | 根据给定的数组生成一个形状一样的全0 数组 | empty | 根据给定形状生成一个没有初始化数值的空数组(通常是0,但也可能是一些未初始化的垃圾数值) | empty_like | 根据给定的数组生成一个形状一样但没有初始化数值的空数组 | full | 根据给定形状和数据类型生成指定数值的数组 | full_like | 根据给定的数组生成一个形状一样但内容是指定数值的数组 | eye,identity | 生成一个 NxN 特征矩阵(对角线位置都是1,其余位置为0) |
1.2.2 java实现
- java中,矩阵的数据结构设计如下
package LinearAlgebra;
import java.math.BigDecimal;
public class Matrix {
private BigDecimal[][] mat;
private int rowNum;
private int colNum;
public Matrix(int rowNum, int colNum) {
this.rowNum = rowNum;
this.colNum = colNum;
mat = new BigDecimal[rowNum][colNum];
initializeMatrix();
}
private void initializeMatrix(){
for (int i = 0; i < rowNum; i++) {
for (int j = 0; j < colNum; j++) {
mat[i][j] = new BigDecimal(0.0);
}
}
}
public void setValue(int x1,int x2,double value){
mat[x1][x2] = new BigDecimal(value);
}
public void setValue(int x1,int x2,BigDecimal value){
mat[x1][x2] = value;
}
public void setValue(BigDecimal[][] matrix){
for (int i = 0; i < rowNum; i++) {
for (int j = 0; j < colNum; j++) {
mat[i][j] = matrix[i][j];
}
}
}
public void setValue(double[][] matrix){
for (int i = 0; i < rowNum; i++) {
for (int j = 0; j < colNum; j++) {
mat[i][j] = new BigDecimal(matrix[i][j]);
}
}
}
public BigDecimal getValue(int x1,int x2){
return mat[x1][x2];
}
public int getRowNum() {
return rowNum;
}
public int getColNum() {
return colNum;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < rowNum; i++) {
for (int j = 0; j < colNum; j++) {
sb.append(String.format("%15f",mat[i][j].doubleValue()));
}
sb.append('\n');
}
return sb.toString();
}
}
- 矩阵运算类示意如下
package LinearAlgebra;
import java.math.BigDecimal;
public class AlgebraUtil {
public final static Matrix add(Matrix a,Matrix b){
}
}
2. 矩阵基本操作
2.1 基本运算(加、减、叉乘、点乘、转置)
2.1.1 python实现
- python中,矩阵相当于2维的 ndarray 数组,可以直接使用numpy库方法实现矩阵的全部基本运算
import numpy as np
m1 = np.array([[1,2,3],
[4,5,6],
[7,8,9]])
m2 = np.array([[9,8,7],
[6,5,4],
[3,2,1]])
result = m1 + m2
result = m1 - m2
result = m1 * m2
result = np.dot(m1,m2)
result = m1.T
print(result)
2.1.2 java实现
- 在
LinearAlgebra 类中添加五个静态类,实现基本运算 加、减、叉乘、点乘、转置
- 矩阵加法:实现静态方法
LinearAlgebra.add ,要求矩阵a和矩阵b尺寸相同public final static Matrix add(Matrix a,Matrix b){
if(a==null || b==null ||
a.getRowNum() != b.getRowNum() ||
a.getColNum() != b.getColNum()) {
return null;
}
Matrix mat = new Matrix(a.getRowNum(),a.getColNum());
for (int i = 0; i < mat.getRowNum(); i++) {
for (int j = 0; j < mat.getColNum(); j++) {
BigDecimal value = a.getValue(i,j).add(b.getValue(i,j));
mat.setValue(i,j,value);
}
}
return mat;
}
- 矩阵减法:实现静态方法
LinearAlgebra.subtract ,要求矩阵a和矩阵b尺寸相同public final static Matrix subtract(Matrix a,Matrix b){
if(a==null || b==null ||
a.getRowNum() != b.getRowNum() ||
a.getColNum() != b.getColNum()) {
return null;
}
Matrix mat = new Matrix(a.getRowNum(),a.getColNum());
for (int i = 0; i < mat.getRowNum(); i++) {
for (int j = 0; j < mat.getColNum(); j++) {
BigDecimal value = a.getValue(i,j).subtract(b.getValue(i,j));
mat.setValue(i,j,value);
}
}
return mat;
}
- 矩阵叉乘:实现静态方法
LinearAlgebra.multiply ,要求矩阵a的列数等于矩阵b的行数public final static Matrix multiply(Matrix a,Matrix b){
if(a==null || b==null || a.getColNum() != b.getRowNum()) {
return null;
}
Matrix mat = new Matrix(a.getRowNum(),b.getColNum());
for (int i = 0; i < mat.getRowNum(); i++) {
for (int j = 0; j < mat.getColNum(); j++) {
BigDecimal value = new BigDecimal(0.0);
for (int c = 0; c < a.getColNum(); c++) {
value = value.add(a.getValue(i,c).multiply(b.getValue(c,j)));
}
mat.setValue(i,j,value);
}
}
return mat;
}
- 矩阵点乘:实现静态方法
LinearAlgebra.dot ,要求矩阵a和矩阵b尺寸相同public final static Matrix dot(Matrix a,Matrix b){
if(a==null || b==null ||
a.getRowNum() != b.getRowNum() ||
a.getColNum() != b.getColNum()) {
return null;
}
Matrix mat = new Matrix(a.getRowNum(),a.getColNum());
for (int i = 0; i < mat.getRowNum(); i++) {
for (int j = 0; j < mat.getColNum(); j++) {
BigDecimal value = a.getValue(i,j).multiply(b.getValue(i,j));
mat.setValue(i,j,value);
}
}
return mat;
}
- 矩阵转置:实现静态方法
LinearAlgebra.transpose public final static Matrix transpose(Matrix a){
if(a==null){
return null;
}
Matrix mat = new Matrix(a.getColNum(),a.getRowNum());
for (int i = 0; i < mat.getRowNum(); i++) {
for (int j = 0; j < mat.getColNum(); j++) {
BigDecimal value = a.getValue(i,j);
mat.setValue(j,i,value);
}
}
return mat;
}
- 测试代码
package LinearAlgebra;
public class Main {
public static void main(String[] args) {
Matrix m1 = new Matrix(3,3);
Matrix m2 = new Matrix(3,3);
m1.setValue(new double[][] {{1,2,3},
{4,5,6},
{7,8,9}});
m2.setValue(new double[][] {{9,8,7},
{6,5,4},
{3,2,1}});
Matrix result;
result = AlgebraUtil.add(m1,m2);
result = AlgebraUtil.subtract(m1,m2);
result = AlgebraUtil.dot(m1,m2);
result = AlgebraUtil.multiply(m1,m2);
result = AlgebraUtil.transpose(m1);
System.out.println(result.toString());
}
}
2.2 其他基本操作(生成单位阵、合并、复制)
2.2.1 python实现
- python中,矩阵相当于2维的 ndarray 数组,可以直接使用numpy库方法实现矩阵的全部基本运算
import numpy as np
m1 = np.array([[1,2,3],
[4,5,6],
[7,8,9]])
m2 = np.array([[9,8,7],
[6,5,4],
[3,2,1]])
result = np.eye(3)
result = np.vstack((m1,m2))
result = np.hstack((m1,m2))
result = m1.copy()
print(result)
2.2.2 java实现
- 在
LinearAlgebra 类中添加三个静态类,实现基本操作 生成单位矩阵、矩阵合并、矩阵复制
- 生成单位矩阵:实现静态方法
LinearAlgebra.identityMatrix ,可以用类似的操作构造其他特殊矩阵public final static Matrix identityMatrix(int dimension){
Matrix mat = new Matrix(dimension,dimension);
for (int i = 0; i < mat.getRowNum(); i++) {
mat.setValue(i,i,1.0);
}
return mat;
}
- 矩阵合并:实现静态方法
LinearAlgebra.mergeMatrix ,使用参数direction 决定合并方向(维度),要求合并维度长度一致public final static Matrix mergeMatrix(Matrix a,Matrix b,int direction){
if(direction == 0){
Matrix mat = new Matrix(a.getRowNum()+b.getRowNum(),a.getColNum());
for (int r = 0; r < a.getRowNum(); r++) {
for (int c = 0; c < a.getColNum(); c++) {
BigDecimal value = a.getValue(r, c);
mat.setValue(r,c,value);
}
}
for (int r = 0; r < b.getRowNum(); r++) {
for (int c = 0; c < b.getColNum(); c++) {
BigDecimal value = b.getValue(r, c);
mat.setValue(r+a.getRowNum(),c,value);
}
}
return mat;
} else if (direction == 1){
Matrix mat = new Matrix(a.getRowNum(),a.getColNum()+b.getColNum());
for (int r = 0; r < a.getRowNum(); r++) {
for (int c = 0; c < a.getColNum(); c++) {
BigDecimal value = a.getValue(r, c);
mat.setValue(r,c,value);
}
}
for (int r = 0; r < b.getRowNum(); r++) {
for (int c = 0; c < b.getColNum(); c++) {
BigDecimal value = b.getValue(r, c);
mat.setValue(r,c+a.getColNum(),value);
}
}
return mat;
} else {
return null;
}
}
- 矩阵复制:实现静态方法
LinearAlgebra.copy public final static Matrix copy(Matrix x){
Matrix mat = new Matrix(x.getRowNum(),x.getColNum());
for (int i = 0; i < mat.getRowNum(); i++) {
for (int j = 0; j < mat.getColNum(); j++) {
BigDecimal value = x.getValue(i,j);
mat.setValue(i,j,value);
}
}
return mat;
}
- 测试代码
package LinearAlgebra;
public class Main {
public static void main(String[] args) {
Matrix m1 = new Matrix(3,3);
Matrix m2 = new Matrix(3,3);
m1.setValue(new double[][] {{1,2,3},
{4,5,6},
{7,8,9}});
m2.setValue(new double[][] {{9,8,7},
{6,5,4},
{3,2,1}});
Matrix result;
result = AlgebraUtil.copy(m1);
result = AlgebraUtil.identityMatrix(3);
result = AlgebraUtil.mergeMatrix(m1,m2,0);
result = AlgebraUtil.mergeMatrix(m1,m2,1);
System.out.println(result.toString());
}
}
|