1. API讲解
??TextCNN做文本分类,所用到的API有两种,一种是继续用tf.nn.conv2d和tf.nn.max_pool/tf.nn.avg_pool。另外一种API是tf.nn.conv1d和tf.nn.max_pool1d/tf.nn.avg_pool1d。如果继续使用tf.nn.conv2d和tf.nn.max_pool/tf.nn.avg_pool虽然API参数要求是一样的,不过对于某些参数的某些维度的值会固定下来。具体可看下方参数详解
- tf.nn.conv2d:
- 参数:
- input:输入,要求的shape为[batch, in_height, in_width, in_channels],由于文本分类的输入shape一般为[batch, seq_length, dim],故需要先将输入expand_dim成[batch, seq_length, dim, 1],这样就满足input参数的要求了
- filter:[filter_height, filter_width, in_channels, out_channels],out_channels表示有几个filter,[filter_heigth,in_width]被称为kernel_size。在文本分类中,filter_height有点像n-gram中的n,filter_width固定为词向量的维度dim,in_channels和input的in_channels一样,为1,out_channels可以任
|