This code was adapted from a TensorFlow 1 implementation of Prototypical Networks.
Original source: https://github.com/sudharsan13296/Hands-On-Meta-Learning-With-Python/tree/master/03.%20Prototypical%20Networks%20and%20its%20Variants
The dataset used is available at the link below:
Link: https://pan.baidu.com/s/1Tp1KCq5oNv-2tq9Rso4lYw  Extraction code: r94v
import os
import glob
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras.layers import (Conv2D,Flatten,MaxPool2D,Dense,BatchNormalization,
Dropout,Activation,Input,GlobalAveragePooling2D)
# GPU settings
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Let GPU memory grow on demand
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
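# Note: the rest of this script runs eagerly under TF2, so the v1 Session above is
# effectively unused; the native TF2 way to get on-demand memory growth would be, e.g.:
#   for gpu in tf.config.list_physical_devices('GPU'):
#       tf.config.experimental.set_memory_growth(gpu, True)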
root_dir = 'data/'
train_split_path = os.path.join(root_dir, 'splits', 'train.txt')
with open(train_split_path, 'r') as train_split:
    train_classes = [line.rstrip() for line in train_split.readlines()]
# number of classes
no_of_classes = len(train_classes)
# print(no_of_classes)  -> 4112
# number of examples per class in the dataset
num_examples = 20
# image width
img_width = 28
# image height
img_height = 28
channels = 1
# number of classes per training episode (N-way)
num_way = 50
# number of samples per class in the support set (K-shot)
num_shot = 5
# number of query points per class in the query set
num_query = 1
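# Each episode therefore draws num_way * (num_shot + num_query) = 50 * 6 = 300 images.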
train_dataset = np.zeros([no_of_classes, num_examples, img_height,img_width], dtype=np.float32)
for label, name in enumerate(train_classes):
    # each line of train.txt looks like "alphabet/character/rotXXX",
    # where rotXXX encodes a rotation angle in degrees
    alphabet, character, rotation = name.split('/')
    rotation = float(rotation[3:])
    img_dir = os.path.join(root_dir, 'data', alphabet, character)
    img_files = sorted(glob.glob(os.path.join(img_dir, '*.png')))
    for index, img_file in enumerate(img_files):
        # invert the grayscale image, apply the rotation, and resize
        values = 1. - np.array(Image.open(img_file).rotate(rotation).resize((img_width, img_height)), np.float32, copy=False)
        train_dataset[label, index] = values
print(train_dataset.shape)
#(4112, 20, 28, 28)
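# (presumably 4112 = 1028 training characters x 4 rotations, as in the Vinyals Omniglot split)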
# one convolutional block
def convolution_block(Data_Input, conv_filter, conv_stride, name='conv'):
    conv = Conv2D(filters=conv_filter, kernel_size=3, strides=conv_stride, padding="same",
                  kernel_initializer=tf.keras.initializers.glorot_uniform(seed=1))(Data_Input)
    conv = MaxPool2D(pool_size=2, strides=2, padding="valid")(conv)
    conv = BatchNormalization()(conv)
    conv = Activation("relu")(conv)
    return conv
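# (For reference: in the original paper each encoder block is conv -> BatchNorm -> ReLU ->
#  max-pool; pooling before BatchNorm/ReLU, as above, is close but not identical.)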
# stack convolution blocks to extract the embedding vector
def get_embeddings(Data_Input, conv_filter, conv_stride=1):
    net = convolution_block(Data_Input, conv_filter[0], conv_stride)
    net = convolution_block(net, conv_filter[1], conv_stride)
    net = convolution_block(net, conv_filter[2], conv_stride)
    net = convolution_block(net, conv_filter[3], conv_stride)
    net = Flatten(name="Flatten")(net)
    net = models.Model(Data_Input, net)
    net.summary()
    return net
# build the embedding network
Data_Input = Input((28,28,1))
conv_filter = [1024,512,256,128]
net = get_embeddings(Data_Input,conv_filter)
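# Quick sanity check (my addition): with 28x28 inputs, four stride-2 max-pools shrink the
# feature map 28 -> 14 -> 7 -> 3 -> 1, so Flatten outputs conv_filter[3] = 128 features:
#   net(np.zeros([1, img_height, img_width, channels], np.float32)).shape == (1, 128)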
# distance function (squared Euclidean, averaged over the feature dimension)
def distance(a, b):
    N, D = a.shape[0], a.shape[1]
    M = b.shape[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    result = tf.reduce_mean(tf.square(a - b), axis=2)
    return result
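# Note: in prototypical networks (Snell et al., 2017) class probabilities are a softmax
# over *negative* distances; a softmax over raw distances (as in the training loop below)
# would give the highest probability to the farthest prototype. For example:
#   logits = -tf.transpose(distance(prototypes, query_embeddings))
#   probs = tf.nn.softmax(logits, axis=-1)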
# use a distinct name so the running-mean metric is not shadowed by the per-step loss below
train_loss = tf.keras.metrics.Mean(name='train_loss')
# the labels below are one-hot, so use CategoricalAccuracy rather than SparseCategoricalAccuracy
accuracy = tf.keras.metrics.CategoricalAccuracy(name='accuracy')
loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
num_epochs = 5
num_episodes = 10
for epoch in range(num_epochs):
    train_loss.reset_states()
    accuracy.reset_states()
    for episode in range(num_episodes):
        # sample 50 classes for this episode
        episodic_classes = np.random.permutation(no_of_classes)[:num_way]
        # support set: (50, 5, 28, 28)
        support = np.zeros([num_way, num_shot, img_height, img_width], dtype=np.float32)
        # query set: (50, 1, 28, 28)
        query = np.zeros([num_way, num_query, img_height, img_width], dtype=np.float32)
        # iterate over the sampled classes
        for index, class_ in enumerate(episodic_classes):
            # each class has 20 images in total; shuffle them and pick 6
            selected = np.random.permutation(num_examples)[:num_shot + num_query]
            # the first 5 points build the support set
            support[index] = train_dataset[class_, selected[:num_shot]]
            # 1 query point per class
            query[index] = train_dataset[class_, selected[num_shot:]]
        # squeeze the per-class query axis, then add a channel dimension
        query = np.squeeze(query)
        query = np.expand_dims(query, axis=-1)
        # build one-hot labels, then squeeze the per-class query axis
        labels = np.tile(np.arange(num_way)[:, np.newaxis], (1, num_query)).astype(np.uint8)
        labels = tf.one_hot(labels, num_way)
        labels = np.squeeze(labels)
        # support.shape: (50, 5, 28, 28)
        # query.shape:   (50, 28, 28, 1)
        # labels.shape:  (50, 50)
        # build the class prototypes: mean embedding of each class's support set
        support_set_embeddings = np.zeros([num_way, conv_filter[3]], dtype=np.float32)
        for i in range(len(support)):
            support_set = np.squeeze(support[i, :, :, :])
            support_set = np.expand_dims(support_set, axis=-1)
            support_set_ = net(support_set)
            support_set_embeddings[i] = tf.reduce_mean(support_set_, axis=0)
        support_set_embeddings = tf.convert_to_tensor(support_set_embeddings, dtype=tf.float32)
        # support_set_embeddings.shape: (50, 128)
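        # NOTE: support_set_embeddings is built as a NumPy array above, so these values
        # are no longer connected to TensorFlow's graph (see the note after the error below).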
        # embed the query set
        query_set = net(query)
        for i in range(num_way):
            with tf.GradientTape() as tape:
                Distance = distance(support_set_embeddings, query_set)
                Distance = tf.transpose(Distance)
                Distance = tf.nn.softmax(Distance)
                predictions = Distance[i]
                loss = loss_object(labels[i], predictions)
            gradients = tape.gradient(loss, net.trainable_variables)
            optimizer.apply_gradients(zip(gradients, net.trainable_variables))
            train_loss(loss)
            accuracy(labels[i], predictions)
        if (episode + 1) % 20 == 0:
            print('Epoch {} : Episode {} : Loss: {}, Accuracy: {}'.format(
                epoch + 1, episode + 1, train_loss.result(), accuracy.result() * 100))
The error is as follows:
ValueError: No gradients provided for any variable: ['conv2d/kernel:0',
'conv2d/bias:0', 'batch_normalization/gamma:0', 'batch_normalization/beta:0',
'conv2d_1/kernel:0', 'conv2d_1/bias:0', 'batch_normalization_1/gamma:0',
'batch_normalization_1/beta:0', 'conv2d_2/kernel:0', 'conv2d_2/bias:0',
'batch_normalization_2/gamma:0', 'batch_normalization_2/beta:0', 'conv2d_3/kernel:0',
'conv2d_3/bias:0', 'batch_normalization_3/gamma:0', 'batch_normalization_3/beta:0'].
I'm a deep-learning beginner and honestly can't figure out what is causing this error. If anyone wants to use this code, feel free to adapt it and experiment, and I'd be very grateful for any corrections!
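One direction I'd look at (an educated guess based on how tf.GradientTape works, not a verified fix): both support_set_embeddings and query_set are computed before the tape starts, and the prototypes are additionally written into a NumPy array, which detaches them from the graph. Inside the tape the loss then depends only on constants, so none of net's trainable variables are tracked, which is exactly what "No gradients provided for any variable" means. A minimal sketch of a training step that keeps the forward pass inside the tape (variable names follow the code above; the softmax is over negative distances, as in the original paper):

with tf.GradientTape() as tape:
    # embed support and query INSIDE the tape so gradients can reach `net`
    support_embeddings = net(tf.reshape(support, [num_way * num_shot, img_height, img_width, channels]))
    prototypes = tf.reduce_mean(tf.reshape(support_embeddings, [num_way, num_shot, -1]), axis=1)  # (50, 128)
    query_embeddings = net(query)                                    # (50, 128)
    logits = -tf.transpose(distance(prototypes, query_embeddings))   # (50, 50): query x class
    predictions = tf.nn.softmax(logits)
    loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, net.trainable_variables)
optimizer.apply_gradients(zip(gradients, net.trainable_variables))
train_loss(loss)
accuracy(labels, predictions)

This also replaces the per-class Python loop with a single vectorized step over all 50 query points, which is how the episode loss is usually computed.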