import tensorflow.contrib.layers as lays
import numpy as np
from skimage import transform
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
# Convolutional denoising autoencoder: an encoder followed by a decoder.
def autoencoder(inputs):
    """Build the encoder/decoder graph on *inputs* and return the reconstruction.

    All layers use 5x5 kernels with 'SAME' padding. The encoder is three
    ``conv2d`` layers (32, 16, 8 filters; strides 2, 2, 4). The decoder is three
    ``conv2d_transpose`` layers (16, 32, 1 filters; strides 4, 2, 2). The last
    layer uses tanh; the others keep the tf.contrib.layers default activation.

    With the (None, 32, 32, 1) placeholder this script feeds in, the spatial
    size goes 32 -> 16 -> 8 -> 2 in the encoder and 2 -> 8 -> 16 -> 32 in the
    decoder, so the output has the same shape as the input.

    NOTE(review): tanh outputs lie in [-1, 1]; if the MNIST targets are in
    [0, 1] (TODO confirm), a sigmoid output may match them better.
    """
    # (filters, stride) per layer, in creation order. Creation order fixes the
    # auto-generated variable scope names, so it must not change.
    encoder_layers = ((32, 2), (16, 2), (8, 4))
    decoder_layers = ((16, 4), (32, 2))

    features = inputs
    for n_filters, step in encoder_layers:
        features = lays.conv2d(features, n_filters, [5, 5], stride=step, padding='SAME')
    for n_filters, step in decoder_layers:
        features = lays.conv2d_transpose(features, n_filters, [5, 5], stride=step, padding='SAME')
    # Final single-channel reconstruction layer with a tanh activation.
    return lays.conv2d_transpose(features, 1, [5, 5], stride=2, padding='SAME',
                                 activation_fn=tf.nn.tanh)


def resize_batch(imgs):
    """Resize a batch of 28x28 images to 32x32.

    32 is a power of two, so the stride-2/stride-4 layers in ``autoencoder``
    divide it evenly (28 would not survive repeated halving cleanly).

    Args:
        imgs: array whose total size is a multiple of 28*28. It is reshaped to
            (N, 28, 28, 1), so flat (N, 784) MNIST rows are accepted.

    Returns:
        A float64 array of shape (N, 32, 32, 1). Each image is resized with
        ``skimage.transform.resize``.
    """
    squares = imgs.reshape((-1, 28, 28, 1))
    resized = np.zeros((squares.shape[0], 32, 32, 1))
    for idx, square in enumerate(squares):
        resized[idx, ..., 0] = transform.resize(square[..., 0], (32, 32))
    return resized


# Adds Gaussian noise to an image.
def noisy(image, mean=0, var=0.1):
    """Return ``image`` plus additive Gaussian noise.

    Args:
        image: NumPy array of any shape. This script passes 2-D 32x32 slices.
        mean: mean of the Gaussian noise (default 0).
        var: variance of the Gaussian noise (default 0.1); must be >= 0.

    Returns:
        A new array ``image + noise`` with the same shape as *image*. Values
        are not clipped, so they can fall outside the input's value range.

    Raises:
        ValueError: if *var* is negative. Without this check, ``var ** 0.5``
            would silently produce a complex sigma.
    """
    if var < 0:
        raise ValueError("var must be non-negative, got {}".format(var))
    sigma = var ** 0.5
    # Draw noise with the image's own shape; no reshape is needed.
    gauss = np.random.normal(mean, sigma, image.shape)
    return image + gauss


# TODO: salt-and-pepper noise is not implemented yet.
# ---------------------------------------------------------------------------
# Training / evaluation script.
# ---------------------------------------------------------------------------
batch_size = 1000  # number of samples per mini-batch
epoch_num = 200  # number of full passes over the training set
lr = 0.001  # learning rate passed to the Adam optimizer
n_test = 50  # number of test images to reconstruct (plotted on a 5x10 grid)


def _add_noise_batch(imgs):
    """Return a noisy (N, 32, 32, 1) copy of an (N, 32, 32, 1) batch via ``noisy``."""
    noisy_imgs = np.array([noisy(img[:, :, 0]) for img in imgs])
    return noisy_imgs.reshape(-1, 32, 32, 1)


def _show_grid(fig_num, title, imgs):
    """Plot the single-channel images of *imgs* (at most 50) on a 5x10 grid."""
    plt.figure(fig_num)
    # suptitle titles the whole figure. plt.title() called before plt.subplot()
    # would title a full-figure axes that the subplots then delete or cover.
    plt.suptitle(title)
    for i in range(len(imgs)):
        plt.subplot(5, 10, i + 1)
        plt.imshow(imgs[i, ..., 0], cmap='gray')
        plt.axis('off')


# Clean images that the reconstruction is compared against.
a_e_inputs = tf.placeholder(tf.float32, (None, 32, 32, 1))
# Network input: noisy MNIST images resized to 32x32.
a_e_inputs_noise = tf.placeholder(tf.float32, (None, 32, 32, 1))
a_e_outputs = autoencoder(a_e_inputs_noise)  # build the autoencoder graph
loss = tf.reduce_mean(tf.square(a_e_outputs - a_e_inputs))  # mean squared error
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)
# Created after minimize() so it also covers the optimizer's slot variables.
init = tf.global_variables_initializer()

mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# At least one step per epoch, even if batch_size exceeds the dataset size.
batch_per_ep = max(1, mnist.train.num_examples // batch_size)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(epoch_num):
        # One epoch walks over batch_per_ep mini-batches of the training set.
        for _ in range(batch_per_ep):
            batch_img, batch_label = mnist.train.next_batch(batch_size)
            # resize_batch reshapes the flat rows to (-1, 28, 28, 1) itself.
            batch_img = resize_batch(batch_img)
            image_arr = _add_noise_batch(batch_img)
            _, c = sess.run([train_op, loss],
                            feed_dict={a_e_inputs_noise: image_arr, a_e_inputs: batch_img})
        # c is the loss of the last mini-batch of this epoch.
        print('Epoch:{} - cost = {:.5f}'.format((epoch + 1), c))

    # Test the trained network on a few noisy test images.
    batch_img, batch_label = mnist.test.next_batch(n_test)
    batch_img = resize_batch(batch_img)
    image_arr = _add_noise_batch(batch_img)
    reconst_img = sess.run(a_e_outputs, feed_dict={a_e_inputs_noise: image_arr})

# Plot the noisy inputs and their reconstructions.
_show_grid(1, 'Input Noisy Images', image_arr)
_show_grid(2, 'Re-constructed Images', reconst_img)
plt.show()
|