1.训练集
由于平台还没有投入使用,暂时没有数据来源。于是先用minist上的手写数据集进行训练和模型搭建。
MNIST数据集是一个有名的手写数字数据集,在深度学习领域,手写数字识别是一个很经典的学习例子。 使用pytorch直接下载数据集
train_data = torchvision.datasets.MNIST(
root='./mnist/',
train=True,
transform=torchvision.transforms.ToTensor(),
download=False,
)
# 将训练数据装入Loader中
train_loader = Data.DataLoader(dataset=train_data, batch_size=50, shuffle=True, num_workers=3)
2.公式推导
我们模型选取的是EM+ 混合高斯模型来估计数据的分布。 与k-means一样,给定的训练样本是{x(1),x(2),…x(m)} ,我们将隐含类别标签用z(i)表示。与k-means的硬指定不同,我们首先认为z(i)是满足一定的概率分布的,可以得到联合分布。
p(x(i),z(j)) = p(x(i)|z(j))p(z(j))
我们的模型中有四个参数,u ,q不同组的平均值和协方差矩阵,样本的权重。w概率矩阵。
l(u,q,p) = Σlog p(x(i),u,p,q)
= Σ log Σ p(x(i)|u,q) p(p,z(j))
求导得 进行概率化 w(i,j) = p(z(j)| x(i),u,q,p)
每一轮E步和M步进行更新
|