2021SC@SDUSC
最后对train.py中的main函数进行分析:
首先读取数据集，并分别为有标签训练集、无标签训练集、验证集和测试集建立数据加载器（DataLoader）：
def main():
    """Load the dataset splits and wrap each one in a DataLoader.

    This snippet is the first part of ``main``; the statements shown in the
    later snippets of this walkthrough (model/optimizer setup and the epoch
    loop) use the names created here.
    """
    # Declared global so that assignments to best_acc rebind the
    # module-level value instead of creating a local.
    global best_acc

    # get_data returns the four dataset splits plus the number of labels.
    train_labeled_set, train_unlabeled_set, val_set, test_set, n_labels = get_data(
        args.data_path, args.n_labeled, args.un_labeled,
        model=args.model, train_aug=args.train_aug)

    # Training loaders are shuffled; the labeled and unlabeled sets each use
    # their own batch size (args.batch_size vs. args.batch_size_u).
    labeled_trainloader = Data.DataLoader(
        dataset=train_labeled_set, batch_size=args.batch_size, shuffle=True)
    unlabeled_trainloader = Data.DataLoader(
        dataset=train_unlabeled_set, batch_size=args.batch_size_u, shuffle=True)

    # Evaluation loaders keep the dataset order and use a fixed batch of 512.
    val_loader = Data.DataLoader(
        dataset=val_set, batch_size=512, shuffle=False)
    test_loader = Data.DataLoader(
        dataset=test_set, batch_size=512, shuffle=False)
然后定义模型（MixText，并用 nn.DataParallel 包装）以及优化器（AdamW，对 bert 和 linear 两部分参数分别设置学习率）：
# Continuation of main(): build the classifier and its optimizer.
# The network is moved to the GPU first and then wrapped in DataParallel,
# so the underlying MixText instance is reachable as model.module.
model = nn.DataParallel(MixText(n_labels, args.mix_option).cuda())

# Two parameter groups: the `bert` sub-module is optimized with
# args.lrmain and the `linear` sub-module with args.lrlast.
param_groups = [
    {"params": model.module.bert.parameters(), "lr": args.lrmain},
    {"params": model.module.linear.parameters(), "lr": args.lrlast},
]
optimizer = AdamW(param_groups)
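开始训练：每个 epoch 先调用 train 进行训练，再在验证集上评估；当验证准确率不低于当前最佳值时，更新最佳值并在测试集上评估。（注意：循环中用到的 scheduler、train_criterion、criterion 和 test_accs 未在本文摘录的代码中给出定义。）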
# Continuation of main(): the epoch loop, shown as a standalone snippet.
# NOTE(review): scheduler, train_criterion, criterion and test_accs are not
# defined in the code shown in this walkthrough; presumably they are set up
# earlier in main() -- confirm against train.py.
for epoch in range(args.epochs):
    # Run train() for this epoch with both training loaders.
    train(labeled_trainloader, unlabeled_trainloader, model, optimizer,
          scheduler, train_criterion, epoch, n_labels, args.train_aug)

    # Evaluate on the validation set after every epoch.
    val_loss, val_acc = validate(
        val_loader, model, criterion, epoch, mode='Valid Stats')
    print("epoch {}, val acc {}, val_loss {}".format(
        epoch, val_acc, val_loss))

    # The test set is evaluated only when validation accuracy matches or
    # beats the best seen so far (">=" means ties count as well).
    if val_acc >= best_acc:
        best_acc = val_acc
        test_loss, test_acc = validate(
            test_loader, model, criterion, epoch, mode='Test Stats ')
        test_accs.append(test_acc)
        print("epoch {}, test acc {},test loss {}".format(
            epoch, test_acc, test_loss))

    # Per-epoch progress report.
    print('Epoch: ', epoch)
    print('Best acc:')
    print(best_acc)
    print('Test acc:')
    print(test_accs)

# Final summary once all epochs have run.
print("Finished training!")
print('Best acc:')
print(best_acc)
print('Test acc:')
print(test_accs)
|