Content
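This excerpt begins below the file's import block. The following reconstruction is inferred from the usage later in the file; the module paths for model_name, BaseDataset, and evaluate are assumptions about the project layout:

import datetime
import importlib
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from config import model_name   # assumed: model_name lives in config.py
from dataset import BaseDataset  # assumed module path
from evaluate import evaluate    # assumed module path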
try:
    Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    print(f"{model_name} not included!")
    exit()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
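importlib.import_module plus getattr is the dynamic equivalent of a from ... import ... statement, selected by a string at runtime. The same pattern with a standard-library module, so the sketch runs anywhere:

import importlib

# Dynamic equivalent of `from json import JSONDecoder`.
module = importlib.import_module('json')
cls = getattr(module, 'JSONDecoder')
print(cls)  # <class 'json.decoder.JSONDecoder'>

getattr raises AttributeError when the named class is missing, which is what the except branch above catches; a missing module file would instead raise ModuleNotFoundError.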
class EarlyStopping
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.counter = 0
        self.best_loss = np.inf

    def __call__(self, val_loss):
        """
        If you use a metric where a higher value is better (e.g. accuracy),
        call this with its negative value instead.
        """
        if val_loss < self.best_loss:
            early_stop = False
            get_better = True
            self.counter = 0
            self.best_loss = val_loss
        else:
            get_better = False
            self.counter += 1
            if self.counter >= self.patience:
                early_stop = True
            else:
                early_stop = False

        return early_stop, get_better
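A quick usage sketch of the class above, with a made-up sequence of validation losses:

early_stopping = EarlyStopping(patience=2)
for val_loss in [0.9, 0.8, 0.85, 0.84]:  # hypothetical values
    early_stop, get_better = early_stopping(val_loss)
    print(val_loss, early_stop, get_better)
# 0.9 and 0.8 each improve on the best loss; 0.85 and 0.84 do not,
# so the second non-improving call reaches patience and early_stop is True.

For a higher-is-better metric such as AUC, pass its negative, exactly as train() does below with early_stopping(-val_auc).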
def latest_checkpoint(directory)
Have a look at how the saved checkpoint files are named:
def latest_checkpoint(directory):
    if not os.path.exists(directory):
        return None
    # Map the step number parsed from names like 'ckpt-{step}.pth' to the file name.
    all_checkpoints = {
        int(x.split('.')[-2].split('-')[-1]): x
        for x in os.listdir(directory)
    }
    if not all_checkpoints:
        return None
    return os.path.join(directory,
                        all_checkpoints[max(all_checkpoints.keys())])
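The dict comprehension keys each checkpoint file by the step number embedded in its name. Tracing the parsing on an example file in the ckpt-{step}.pth format that train() saves:

x = 'ckpt-1200.pth'  # example name
print(x.split('.'))                          # ['ckpt-1200', 'pth']
print(x.split('.')[-2].split('-'))           # ['ckpt', '1200']
print(int(x.split('.')[-2].split('-')[-1]))  # 1200

max() over these integer keys then selects the most recent checkpoint.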
def train()
The log_dir passed to the TensorBoard SummaryWriter is built from the model name, the current timestamp, and an optional REMARK environment variable:
def train():
    writer = SummaryWriter(
        log_dir=
        f"./runs/{model_name}/{datetime.datetime.now().replace(microsecond=0).isoformat()}{'-' + os.environ['REMARK'] if 'REMARK' in os.environ else ''}"
    )

    if not os.path.exists('checkpoint'):
        os.makedirs('checkpoint')

    try:
        pretrained_word_embedding = torch.from_numpy(
            np.load('./data/train/pretrained_word_embedding.npy')).float()
    except FileNotFoundError:
        pretrained_word_embedding = None
    if model_name == 'DKN':
        try:
            pretrained_entity_embedding = torch.from_numpy(
                np.load(
                    './data/train/pretrained_entity_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_entity_embedding = None
        try:
            pretrained_context_embedding = torch.from_numpy(
                np.load(
                    './data/train/pretrained_context_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_context_embedding = None
        model = Model(config, pretrained_word_embedding,
                      pretrained_entity_embedding,
                      pretrained_context_embedding)
    else:
        # Restored branch: non-DKN models take only the word embedding.
        model = Model(config, pretrained_word_embedding)
    print(torch.cuda.device_count())
    if torch.cuda.device_count() > 1:
        device_ids = [0, 1]
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    model.to(device)

    if model_name != 'Exp1':
        print(model)
    else:
        # `models`, the per-component model list used by Exp1, is constructed
        # elsewhere in the file and omitted from this excerpt.
        print(models[0])

    dataset = BaseDataset('data/train/behaviors_parsed.tsv',
                          'data/train/news_parsed.tsv', 'data/train/roberta')
    print(f"Load training dataset with size {len(dataset)}.")

    dataloader = iter(
        DataLoader(dataset,
                   batch_size=config.batch_size,
                   shuffle=True,
                   num_workers=config.num_workers,
                   drop_last=True,
                   pin_memory=True))

    if model_name != 'Exp1':
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.learning_rate)
    else:
        criterion = nn.NLLLoss()
        optimizers = [
            torch.optim.Adam(model.parameters(), lr=config.learning_rate)
            for model in models
        ]
    start_time = time.time()
    loss_full = []
    exhaustion_count = 0
    step = 0
    early_stopping = EarlyStopping()

    checkpoint_dir = os.path.join('./checkpoint', model_name)
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

    checkpoint_path = latest_checkpoint(checkpoint_dir)
    if checkpoint_path is not None:
        print(f"Load saved parameters in {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        early_stopping(checkpoint['early_stop_value'])
        step = checkpoint['step']
        if model_name != 'Exp1':
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            model.train()
        else:
            for model in models:
                model.load_state_dict(checkpoint['model_state_dict'])
                model.train()
            for optimizer in optimizers:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    for i in tqdm(range(
            1, config.num_epochs * len(dataset) // config.batch_size + 1),
                  desc="Training"):
        try:
            minibatch = next(dataloader)
        except StopIteration:
            exhaustion_count += 1
            tqdm.write(
                f"Training data exhausted for {exhaustion_count} times after {i} batches, reuse the dataset."
            )
            dataloader = iter(
                DataLoader(dataset,
                           batch_size=config.batch_size,
                           shuffle=True,
                           num_workers=config.num_workers,
                           drop_last=True,
                           pin_memory=True))
            minibatch = next(dataloader)

        step += 1
        y_pred = model(minibatch["candidate_news"],
                       minibatch["clicked_news"])
        y = torch.zeros(len(y_pred)).long().to(device)
        loss = criterion(y_pred, y)
        loss_full.append(loss.item())

        if model_name != 'Exp1':
            optimizer.zero_grad()
        else:
            for optimizer in optimizers:
                optimizer.zero_grad()
        loss.backward()
        if model_name != 'Exp1':
            optimizer.step()
        else:
            for optimizer in optimizers:
                optimizer.step()
        if i % 10 == 0:
            writer.add_scalar('Train/Loss', loss.item(), step)

        if i % config.num_batches_show_loss == 0:
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, current loss {loss.item():.4f}, average loss: {np.mean(loss_full):.4f}, latest average loss: {np.mean(loss_full[-256:]):.4f}"
            )

        if i % config.num_batches_validate == 0:
            (model if model_name != 'Exp1' else models[0]).eval()
            val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(
                model if model_name != 'Exp1' else models[0], './data/val',
                200000)
            (model if model_name != 'Exp1' else models[0]).train()
            writer.add_scalar('Validation/AUC', val_auc, step)
            writer.add_scalar('Validation/MRR', val_mrr, step)
            writer.add_scalar('Validation/nDCG@5', val_ndcg5, step)
            writer.add_scalar('Validation/nDCG@10', val_ndcg10, step)
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, validation AUC: {val_auc:.4f}, validation MRR: {val_mrr:.4f}, validation nDCG@5: {val_ndcg5:.4f}, validation nDCG@10: {val_ndcg10:.4f}"
            )

            early_stop, get_better = early_stopping(-val_auc)
            if early_stop:
                tqdm.write('Early stop.')
                break
            elif get_better:
                try:
                    torch.save(
                        {
                            'model_state_dict':
                            (model if model_name != 'Exp1' else
                             models[0]).state_dict(),
                            'optimizer_state_dict':
                            (optimizer if model_name != 'Exp1' else
                             optimizers[0]).state_dict(),
                            'step': step,
                            'early_stop_value': -val_auc
                        }, f"./checkpoint/{model_name}/ckpt-{step}.pth")
                except OSError as error:
                    print(f"OS error: {error}")
def time_since(since)
def time_since(since):
    """
    Format elapsed time string.
    """
    now = time.time()
    elapsed_time = now - since
    return time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
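For example, an elapsed time of 3661 seconds formats as one hour, one minute, one second:

import time
print(time.strftime("%H:%M:%S", time.gmtime(3661)))  # 01:01:01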
if __name__ == '__main__':
    print(f'Training model {model_name}')
    train()
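Because the log_dir string appends os.environ['REMARK'] when that variable is set, a run can be tagged from the shell, e.g. (assuming the file is saved as train.py, which this excerpt does not state):

REMARK=baseline python3 train.py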
Supplement
1. The os.listdir() method
Overview
os.listdir() returns a list with the names of the files and folders contained in the given directory.
The special entries '.' and '..' are not included, even though they exist in the directory.
It is available on Unix and Windows.
Syntax
The signature of listdir() is:
os.listdir(path)
Parameters
path – the path of the directory to list
Return value
A list of the files and folders under the given path.
Example
import os

path = "/var/www/html/"
dirs = os.listdir(path)

for file in dirs:
    print(file)
2. The Python replace() method
Description
Python's replace() method replaces occurrences of old (the old substring) in a string with new (the new substring); if the optional third argument max is given, at most max replacements are performed.
Syntax
The signature of replace() is:
str.replace(old, new[, max])
Parameters
- old – the substring to be replaced.
- new – the string that replaces old.
- max – optional integer; replace at most max occurrences.
Return value
A new string in which old has been replaced by new; if max is given, at most max replacements are made.
Example
str = "this is string example....wow!!! this is really string";
print str.replace("is", "was");
print str.replace("is", "was", 3);
thwas was string example....wow!!! thwas was really string
thwas was string example....wow!!! thwas is really string
3. Testing datetime
import datetime

print(datetime.datetime.now())
print(datetime.datetime.now().replace(microsecond=0))
print(datetime.datetime.now().replace(microsecond=0).isoformat())
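Put together, the three calls explain the log_dir built in train(): isoformat() on a microsecond-free datetime yields a compact, sortable timestamp such as 2024-01-01T12:34:56 (the exact value depends on when you run it). With a hypothetical model_name of 'NRMS' and REMARK unset, the log directory would look like:

./runs/NRMS/2024-01-01T12:34:56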
4. NLLLoss and CrossEntropyLoss
https://blog.csdn.net/qq_22210253/article/details/85229988
NLLLoss stands for Negative Log Likelihood Loss; minimizing it amounts to maximum likelihood estimation.
Both losses are used for single-label image classification. [Note: NLLLoss and CrossEntropyLoss are both for single-label classification, while BCELoss and BCEWithLogitsLoss are for multi-label classification. "Multi-label" here means one sample corresponds to several labels.]
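The relationship between the two losses is easy to verify: CrossEntropyLoss applied to raw logits equals NLLLoss applied to the LogSoftmax of those logits. A minimal sketch with made-up tensor values:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])  # 2 samples x 3 classes
target = torch.tensor([0, 1])  # correct class index per sample

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
print(ce.item(), nll.item())  # identical values

This also explains the split in train(): CrossEntropyLoss expects raw scores, while NLLLoss expects log-probabilities, which the Exp1 models presumably output directly.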