CDIMC-Net[1] 中有个对整个数据集求 kNN 图的函数 get_kNNgraph2 [2],是用 dense 的 numpy.ndarray 存的,空间复杂度为 $O(n^2)$,大数据集很吃内存,但其实 kNN 图很稀疏。这里用 scipy.sparse 的 API 改写。
Code
- csr_matrix:row slicing 高效,因为一行对应一个 datum 的邻接链表,取 batch 是对行取,所以用它。
- lil_matrix:文档说它「改变稀疏结构很高效」,所以用在图的构造阶段,构造完再转成 csr_matrix(本来是直接用 csr_matrix 构造的,但 scipy 给出效率警告,建议改用 lil_matrix)。
import numpy as np
from scipy.sparse import csr_matrix, lil_matrix
def get_kNNgraph2(data, K_num):
    """Build a dense, symmetric kNN adjacency matrix over the rows of ``data``.

    This is the original (dense) graph-construction routine from CDIMC-Net:
    https://github.com/DarrenZZhang/CDIMC-Net/blob/main/CDIMC-net-handwritten_final.py#L46

    Args:
        data: 2-D array with one sample per row.
        K_num: number of smallest-distance columns marked in each row before
            the diagonal is cleared.

    Returns:
        An ``(n, n)`` integer ndarray of 0/1 values with a zero diagonal.
        Entry ``(i, j)`` is 1 when ``j`` is among the ``K_num`` smallest
        distances of row ``i``, or ``i`` is among those of row ``j``.
    """
    # Squared Euclidean distances via ||a||^2 - 2 a.b + ||b||^2.
    sq_norms = np.sum(np.square(data), 1)
    x_norm = np.reshape(sq_norms, [-1, 1])
    x_norm2 = np.reshape(sq_norms, [1, -1])
    dists = x_norm - 2 * np.matmul(data, np.transpose(data)) + x_norm2

    num_sample = data.shape[0]
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is what that
    # alias referred to, so the resulting dtype is unchanged.
    graph = np.zeros((num_sample, num_sample), dtype=int)
    for i in range(num_sample):
        small_index = np.argsort(dists[i, :])
        graph[i, small_index[0:K_num]] = 1

    # Drop self-loops, then symmetrise: keep an edge if either endpoint
    # selected the other.
    np.fill_diagonal(graph, 0)
    resultgraph = np.maximum(graph, np.transpose(graph))
    return resultgraph
def get_kNNgraph2_sparse(X, K_num, batch_size=256):
    """Build the symmetric kNN adjacency matrix of ``X`` in sparse form.

    Distances are computed one row-batch at a time, so only a
    ``(batch_size, n)`` distance block is held at once instead of the full
    ``(n, n)`` matrix.

    Args:
        X: 2-D array with one sample per row.
        K_num: number of smallest-distance columns marked in each row before
            the diagonal is cleared.
        batch_size: number of rows processed per distance block.

    Returns:
        An ``(n, n)`` CSR matrix of 0/1 values with a zero diagonal. An edge
        is present when either endpoint selected the other.
    """
    num_rows = X.shape[0]
    # LIL is used while filling because the sparsity structure changes on
    # every assignment.
    adjacency = lil_matrix((num_rows, num_rows), dtype=np.int8)

    # Squared norms of every sample as a (1, n) row, shared by all batches.
    all_sq_norms = np.sum(np.square(X), axis=1, keepdims=True).T

    for start in range(0, num_rows, batch_size):
        stop = min(start + batch_size, num_rows)
        chunk = X[start:stop]
        chunk_sq_norms = np.sum(np.square(chunk), axis=1, keepdims=True)
        cross = np.matmul(chunk, np.transpose(X))
        # Squared Euclidean distances from each batch row to all samples.
        dist_block = chunk_sq_norms - 2 * cross + all_sq_norms
        nearest = np.argsort(dist_block, axis=1)[:, :K_num]
        for offset, neighbour_cols in enumerate(nearest):
            adjacency[start + offset, neighbour_cols] = 1

    # Remove self-loops, then symmetrise.
    adjacency.setdiag(0)
    symmetric = adjacency.maximum(adjacency.transpose())
    return symmetric.tocsr()
# Consistency check: on small random inputs, the dense and the sparse
# implementations should produce the same adjacency matrix.
N, D = 6, 3  # number of samples, feature dimension
K = N // 2
for i in range(150):
    # Distinct integer coordinates, reshaped into N samples of D features.
    X = np.random.permutation(N * D).reshape(N, D)
    G1 = get_kNNgraph2(X, K)
    G2 = get_kNNgraph2_sparse(X, K).todense()
    # Number of entries where the two graphs disagree.
    diff = np.count_nonzero(G1 != G2)
    if diff:
        print("diff:", i, diff)
print("DONE")
References
- DarrenZZhang/CDIMC-Net
- get_kNNgraph2
- Sparse matrices (scipy.sparse)
- scipy.sparse.csr_matrix
- scipy.sparse.lil_matrix
- torch.sparse
|