IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 知识图谱DKN源码详解(四)train.py -> 正文阅读

[人工智能]知识图谱DKN源码详解(四)train.py

内容

try:  #不用多言, 获得该模块下的model_name函数
    Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    print(f"{model_name} not included!")
    exit()
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EarlyStopping

class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience   
        self.counter = 0
        self.best_loss = np.Inf

    def __call__(self, val_loss):
        """
        if you use other metrics where a higher value is better, e.g. accuracy,
        call this with its corresponding negative value
        """
        # 如果你使用的其他指标值越高越好,例如准确性,用它对应的负数来调用它
        if val_loss < self.best_loss:   #如果评测的损失小于最好的损失,那么就是最好的损失
            early_stop = False
            get_better = True
            self.counter = 0
            self.best_loss = val_loss  # 最好的损失 
        else:
            get_better = False         #  
            self.counter += 1
            if self.counter >= self.patience:
                early_stop = True
            else:
                early_stop = False

        return early_stop, get_better  

def latest_checkpoint(directory):

看一看存储的模型路径名称:
在这里插入图片描述

def latest_checkpoint(directory):   #最新的检查点! 
	if not os.path.exists(directory):  #该路径在不在
	        return None
   	all_checkpoints = {   #{10000 : ckpt-10000.pth, 11000: ckpt-11000.pth}  这就是最终的结果
        int(x.split('.')[-2].split('-')[-1]): x
        for x in os.listdir(directory)
    }
    if not all_checkpoints:   #如果没有checkpoint,就返回空
        return None
    return os.path.join(directory,   #我们选择keys最大的选择
                        all_checkpoints[max(all_checkpoints.keys())])

def train()

log_dir:
在这里插入图片描述

def train():
    writer = SummaryWriter(  #这里的路径!  runs/DKN/.....
        log_dir=
       f"./runs/{model_name}/{datetime.datetime.now().replace(microsecond=0).isoformat()}
       {'-' + os.environ['REMARK'] if 'REMARK' in os.environ else ''}"
    )

    if not os.path.exists('checkpoint'):  #如果没有checkpoint,那么就需要在当前目录下创建checkpoint
        os.makedirs('checkpoint')

    try:
        pretrained_word_embedding = torch.from_numpy(  #读入预训练单词嵌入
            np.load('./data/train/pretrained_word_embedding.npy')).float()
    except FileNotFoundError:
        pretrained_word_embedding = None

    if model_name == 'DKN':   #如果是DKN模型
        try:
            pretrained_entity_embedding = torch.from_numpy(   #如果是DKN,嵌入实体
                np.load(
                    './data/train/pretrained_entity_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_entity_embedding = None

        try:
            pretrained_context_embedding = torch.from_numpy(  #预训练上下文嵌入  但是numpy是在CPU上的! 
                np.load(
                    './data/train/pretrained_context_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_context_embedding = None
        model = Model(config, pretrained_word_embedding,   #创建模型
                      pretrained_entity_embedding,
                      pretrained_context_embedding)
        print(torch.cuda.device_count())   #这里是自己加的,想要实现并行操作! 
        if torch.cuda.device_count() > 1:   #如果设备数目大于1,那么就并行操作
            # model.to(device)
            device_ids = [0, 1]
            model = torch.nn.DataParallel(model, device_ids=device_ids)
            model.to(device)
        # for param in next(model.parameters()):
        #     print(param, param.device)
        # print(next(model.parameters()).device)

    if model_name != 'Exp1':
        print(model)
    else:
        print(models[0])

    dataset = BaseDataset('data/train/behaviors_parsed.tsv',
                          'data/train/news_parsed.tsv', 'data/train/roberta')
    #获得原数据集
	
    print(f"Load training dataset with size {len(dataset)}.")

    dataloader = iter(   #改成dataloader,并被迭代器包装,使得每次访问只需要next()即可  
        DataLoader(dataset,   #由于自己原来接触过dataloader所以这里是懂点的,不再解释
                   batch_size=config.batch_size,
                   shuffle=True,
                   num_workers=config.num_workers,
                   drop_last=True,
                   pin_memory=True))
    if model_name != 'Exp1':
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.learning_rate)
    else:
        criterion = nn.NLLLoss()  #最大似然函数
        optimizers = [             #定义优化器
            torch.optim.Adam(model.parameters(), lr=config.learning_rate)
            for model in models
        ]
    start_time = time.time()    #定义开始的时间
    loss_full = []       #全部损失
    exhaustion_count = 0  #竭尽全力_count???
    step = 0   
    early_stopping = EarlyStopping()  #早点结束,看上面的函数定义

    checkpoint_dir = os.path.join('./checkpoint', model_name)  #检查点/model_name
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)    #创建checkpoint目录

    checkpoint_path = latest_checkpoint(checkpoint_dir)  #获得最新的检查点
    if checkpoint_path is not None:          #开始带入checkpoint
        print(f"Load saved parameters in {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)   #加载检查点,里面的格式是字典类型的
        early_stopping(checkpoint['early_stop_value'])   #
        step = checkpoint['step']     #
        if model_name != 'Exp1':
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            model.train()
        else:
            for model in models:   
                model.load_state_dict(checkpoint['model_state_dict'])  #直接加载模型参数
                model.train()  
            for optimizer in optimizers:   #直接加载优化器参数
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for i in tqdm(range(    # epochs * (len(dataset) // config.batch_size + 1)这么多次迭代
            1,
            config.num_epochs * len(dataset) // config.batch_size + 1),
                  desc="Training"):
        try:   #获取小dataloader中的batch
            minibatch = next(dataloader)
            # if torch.cuda.device_count() > 1:
            #     minibatch = torch.nn.DataParallel(minibatch)
            #     minibatch.to(device)
            # minibatch.to(device)
        except StopIteration:  #如果迭代出问题了
            exhaustion_count += 1
            tqdm.write(
                f"Training data exhausted for {exhaustion_count} times after {i} batches, reuse the dataset."
            )
            dataloader = iter(
                DataLoader(dataset,
                           batch_size=config.batch_size,
                           shuffle=True,
                           num_workers=config.num_workers,
                           drop_last=True,
                           pin_memory=True))
            minibatch = next(dataloader)


        step += 1

        y_pred = model(minibatch["candidate_news"],  #结算损失, 候选新闻是预测得到的!
                        minibatch["clicked_news"])

        y = torch.zeros(len(y_pred)).long().to(device)
        loss = criterion(y_pred, y)

        loss_full.append(loss.item())  #要保存损失的
        if model_name != 'Exp1':
            optimizer.zero_grad()
        else:
            for optimizer in optimizers:  #优化器更新权重
                optimizer.zero_grad()
        loss.backward()
        if model_name != 'Exp1':
            optimizer.step()
        else:
            for optimizer in optimizers:
                optimizer.step()

        if i % 10 == 0:   #如果10次计算了,那么就写入我们的损失
            writer.add_scalar('Train/Loss', loss.item(), step)

        if i % config.num_batches_show_loss == 0:  #写出结果
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, current loss {loss.item():.4f}, average loss: {np.mean(loss_full):.4f}, latest average loss: {np.mean(loss_full[-256:]):.4f}"
            )

        if i % config.num_batches_validate == 0:   #
            (model if model_name != 'Exp1' else models[0]).eval()
            val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(
                model if model_name != 'Exp1' else models[0], './data/val',
                200000)
            (model if model_name != 'Exp1' else models[0]).train()
            writer.add_scalar('Validation/AUC', val_auc, step)
            writer.add_scalar('Validation/MRR', val_mrr, step)
            writer.add_scalar('Validation/nDCG@5', val_ndcg5, step)
            writer.add_scalar('Validation/nDCG@10', val_ndcg10, step)
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, validation AUC: {val_auc:.4f}, validation MRR: {val_mrr:.4f}, validation nDCG@5: {val_ndcg5:.4f}, validation nDCG@10: {val_ndcg10:.4f}, "
            )
			#后面的都是如果是最好的效果,就保存模型参数
            early_stop, get_better = early_stopping(-val_auc)
            if early_stop:
                tqdm.write('Early stop.')
                break
            elif get_better:
                try:
                    torch.save(
                        {
                            'model_state_dict': (model if model_name != 'Exp1'
                                                 else models[0]).state_dict(),
                            'optimizer_state_dict':
                            (optimizer if model_name != 'Exp1' else
                             optimizers[0]).state_dict(),
                            'step':
                            step,
                            'early_stop_value':
                            -val_auc
                        }, f"./checkpoint/{model_name}/ckpt-{step}.pth")
                except OSError as error:
                    print(f"OS error: {error}")

def time_since(since)

def time_since(since):   #运行了多长时间
    """
    Format elapsed time string.
    """
    now = time.time()
    elapsed_time = now - since  #
    return time.strftime("%H:%M:%S", time.gmtime(elapsed_time))


if __name__ == '__main__':
    # print('Using device:', device)
    print(f'Training model {model_name}')
    train()

补充

1. os.listdir() 方法

概述

os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。(是该文件夹下所有的文件名)

它不包括 . 和 … 即使它在文件夹中。

只支持在 Unix, Windows 下使用。

语法

listdir()方法语法格式如下:

os.listdir(path)

参数

path – 需要列出的目录路径

返回值

返回指定路径下的文件和文件夹列表。

实例

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import os, sys

# 打开文件
path = "/var/www/html/"
dirs = os.listdir( path )

# 输出所有文件和文件夹
for file in dirs:
   print (file)

在这里插入图片描述

2. Python replace()方法

描述

Python replace() 方法把字符串中的 old(旧字符串) 替换成 new(新字符串),如果指定第三个参数max,则替换不超过 max 次。

语法

replace()方法语法:

str.replace(old, new[, max])

参数

  • old – 将被替换的子字符串。
  • new – 新字符串,用于替换old子字符串。
  • max – 可选字符串, 替换不超过 max 次

返回值

返回字符串中的 old(旧字符串) 替换成 new(新字符串)后生成的新字符串,如果指定第三个参数max,则替换不超过 max 次。

实例

str = "this is string example....wow!!! this is really string";
print str.replace("is", "was");
print str.replace("is", "was", 3);

thwas was string example....wow!!! thwas was really string
thwas was string example....wow!!! thwas is really string

3. datetime测试

print(datetime.datetime.now())   #2021-08-27 09:47:48.748545
print(datetime.datetime.now().replace(microsecond=0))  #2021-08-27 09:48:26
print(datetime.datetime.now().replace(microsecond=0).isoformat())  #2021-08-27T09:49:18

4. NLLLoss 和 CrossEntropyLoss

https://blog.csdn.net/qq_22210253/article/details/85229988

NLLLoss的全称是Negative Log Likelihood Loss,也就是最大似然函数

在图片进行单标签分类时,【注意NLLLoss和CrossEntropyLoss都是用于单标签分类,而BCELoss和BECWithLogitsLoss都是使用与多标签分类。这里的多标签是指一个样本对应多个label.】

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-27 11:51:08  更:2021-08-27 11:51:54 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 18:35:15-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码