1. pycuda
??如果安装出错,把下面的路径添加成环境变量。
pip install pycuda
2. 矩阵乘法
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
mod = SourceModule("""
__global__ void matrix_mul(float *dest, float *a, float *b, int l, int m, int n) {
int i = threadIdx.x + blockDim.x * blockIdx.x;
int j = threadIdx.y + blockDim.y * blockIdx.y;
if (i >= n || j >= l)
return;
dest[j * n + i] = 0;
for (int k = 0; k < m; k++)
dest[j * n + i] += a[j * m + k] * b[k * n + i];
}
""")
if __name__ == '__main__':
matrix_mul = mod.get_function("matrix_mul")
mat_a = np.random.randn(4, 40).astype(np.float32)
mat_b = np.random.randn(40, 4).astype(np.float32)
l, m, n = mat_a.shape[0], mat_a.shape[1], mat_b.shape[1]
dest = np.zeros((l, n), dtype=np.float32)
matrix_mul(cuda.Out(dest), cuda.In(mat_a), cuda.In(mat_b),
np.int32(l), np.int32(m), np.int32(n),
block=(n, l, 1), grid=(1, 1))
print(dest)
print()
print(np.dot(mat_a, mat_b))
??第 26 的 l、m、n 需要转换成 numpy 类型,而第 27 行的 l、n 不能转换成 numpy 类型。 ??第 8、9 行的 i、j 是像素坐标,i 的正方向是从左向右,j 的正方向是从上向下。i、j 和第 27 行的 n、l 对应,分别表示横竖方向,所以要注意 n、l 的顺序。 ??第 10 行限制线程坐标,即使第 27 行使用更大维度创建更多线程,多余的线程也不会参加运算。
3. 参考
- 安装 pycuda 出错
- pycuda 教程
- numba 教程
- pycuda 和 numba
|