# 使用RNN Module构建的一个字符串转换功能 (a string-conversion demo built with torch.nn.RNN)
import torch
import torch.optim as optim
class Model(torch.nn.Module):
    """
    Character-level RNN transducer.

    Maps a one-hot-encoded input sequence of shape
    (seq_len, batch, input_size) to per-step logits of shape
    (seq_len * batch, hidden_size), suitable for CrossEntropyLoss.
    """
    def __init__(self, input_size, hidden_size, batch_size, num_layers):
        """
        Args:
            input_size: feature dimension of each sequence element (vocab size).
            hidden_size: dimension of the RNN hidden state (here: vocab size,
                so each hidden unit doubles as a class logit).
            batch_size: default batch size used to build the initial hidden state.
            num_layers: number of stacked RNN layers.
        """
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        # One torch.nn.RNN applies the recurrent cell over the whole
        # sequence with weights shared across time steps.
        # BUG FIX: the original omitted num_layers, so any value other than 1
        # made the zero hidden state built in forward() mismatch the RNN.
        self.rnn = torch.nn.RNN(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers)

    def forward(self, input, **args):
        """
        Run the RNN over `input` of shape (seq_len, batch, input_size).

        Keyword args:
            batch_size: optionally override the stored batch size for this
                and subsequent calls (used at inference time).

        Returns:
            Tensor of shape (seq_len * batch, hidden_size).
        """
        if 'batch_size' in args:
            self.batch_size = args['batch_size']
        # Zero initial hidden state: (num_layers, batch, hidden_size).
        hidden = torch.zeros(
            self.num_layers,
            self.batch_size,
            self.hidden_size)
        out, _ = self.rnn(input, hidden)
        # Flatten time and batch dims so the loss sees (N, num_classes).
        return out.view(-1, self.hidden_size)
if __name__ == "__main__":
    num_layers = 1  # number of stacked RNN layers
    # Vocabulary: A-Z, 0-9, arithmetic operators, '=' and space.
    idx2char = [chr(x) for x in range(ord('A'), ord('Z') + 1)] \
        + [chr(x) for x in range(ord('0'), ord('9') + 1)] \
        + ['+', '-', '*', '/', '=', ' ']
    input_size = len(idx2char)   # feature dimension of each sequence element
    hidden_size = len(idx2char)  # hidden-state dim: one logit per vocab entry
    print(idx2char)

    # Training pair: the model learns a fixed character-to-character mapping.
    x_data = ["xihuanliaojiexuexilehuatuan"]
    y_data = ["hifuanliaogaihaxolaofuatuen"]
    batch_size = len(x_data)  # number of sequences per batch
    seq_len = len(x_data[0])  # length of each sequence
    x_data = [idx2char.index(c) for item in x_data for c in item.upper()]
    y_data = [idx2char.index(c) for item in y_data for c in item.upper()]
    print(x_data, y_data)

    # An identity matrix doubles as a one-hot lookup table:
    # row i is the one-hot encoding of vocabulary index i.
    one_hot_lookup = torch.diag(torch.ones(input_size, dtype=torch.int32))
    x_one_hot = one_hot_lookup[x_data]
    print(x_one_hot)
    inputs = (x_one_hot.float()).view(seq_len, batch_size, input_size)
    labels = torch.LongTensor(y_data)

    model = Model(input_size, hidden_size, batch_size, num_layers)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Training loop (the original comment mislabeled this as "testing").
    for epoch in range(100):
        optimizer.zero_grad()        # reset accumulated gradients
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()              # back-propagate
        optimizer.step()             # update parameters
        _, idx = outputs.max(dim=1)  # greedy decode: most likely char per step
        print("EPOCH: ", epoch + 1, loss.item(), end=" ")
        print("Predicted String: ", end=" ")
        print("".join([idx2char[x] for x in idx]))

    # Inference on user input, one sequence at a time.
    batch_size = 1
    myinput = input("请输入你要转换的序列:")
    # BUG FIX: characters outside the vocabulary used to crash
    # idx2char.index() with ValueError; map them to the space token instead.
    space_idx = idx2char.index(' ')
    test_x_data = [idx2char.index(c) if c in idx2char else space_idx
                   for c in myinput.upper()]
    if test_x_data:  # BUG FIX: an empty input cannot be reshaped by view()
        with torch.no_grad():  # inference only: no gradients needed
            x_one_hot = one_hot_lookup[test_x_data]
            inp = (x_one_hot.float()).view(len(test_x_data), batch_size, input_size)
            outputs = model(inp, **{"batch_size": batch_size})
            _, idx = outputs.max(dim=1)
            print("".join([idx2char[x] for x in idx]), end="")