
[Python Knowledge Base] do_train

logging.conf

# logger.conf - modified

###############################################
#    root: the base logger
#    std: everything to standard output (the terminal); typically used while debugging
#    file: everything to files; typically used by online services to persist logs
[loggers]
keys=root,std,file

[logger_root]
level=DEBUG
handlers=hand01,hand02,hand03

[logger_std]
handlers=hand01,hand02,hand03
qualname=std
propagate=0

[logger_file]
handlers=hand01,hand01_file,hand02_file
qualname=file
propagate=0

###############################################

[handlers]
keys=hand01,hand02,hand03,hand01_file,hand02_file

[handler_hand01]
# fileConfig only reads class/level/formatter/args in handler sections; a bare
# encoding key is silently ignored, and StreamHandler takes no encoding argument
class=logging.StreamHandler
level=WARNING
formatter=form02
args=(sys.stderr,)


[handler_hand02]
class=logging.StreamHandler
level=ERROR
formatter=form03
args=(sys.stderr,)


[handler_hand03]
class=logging.StreamHandler
level=INFO
formatter=form03
args=(sys.stderr,)


[handler_hand01_file]
class=logging.FileHandler
level=ERROR
formatter=form01
# encoding must be passed positionally through args; a bare encoding key is ignored
args=('log/error.log', 'a', 'utf-8')


[handler_hand02_file]
class=logging.handlers.RotatingFileHandler
level=INFO
formatter=form01
# rotate at 500 MB, keep 5 backups; encoding passed positionally
args=('log/log.log', 'a', 500*1024*1024, 5, 'utf-8')



###############################################

[formatters]
keys=form01,form02,form03

[formatter_form01]
format=%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s
datefmt=%Y %m %d %H:%M:%S

[formatter_form02]
format=%(name)-12s: %(levelname)-8s %(message)s
datefmt=%Y %m %d %H:%M:%S

[formatter_form03]
format=[%(asctime)s][%(levelname)s]  %(message)s
datefmt=%Y %m %d %H:%M:%S
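
Note that both file handlers write under a relative log/ directory, which FileHandler does not create on its own, so the directory must exist before the config is loaded. A minimal sketch of loading and exercising this config (the ./config/logging.conf path matches the services below; adjust it to wherever the file lives):

# sketch: load logging.conf and emit through both configured loggers
import os
import logging
import logging.config

os.makedirs('log', exist_ok=True)        # FileHandler will not create the directory
logging.config.fileConfig('./config/logging.conf')

std_logger = logging.getLogger('std')    # terminal only (propagate=0)
file_logger = logging.getLogger('file')  # terminal + error.log + rotating log.log

std_logger.info("reaches stderr via hand03 (level=INFO)")
file_logger.error("reaches stderr and both log files")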

multinegativeentity.py

# -*- coding: utf-8 -*-
"""
 @about  : 负面实体服务
 @time   : 2020/11/24 10:19
 @author : fangsh
"""
import json
import time
import re
import os
import torch
import torch.nn as nn
import numpy as np
from flask import Flask, request
from transformers import BertModel, BertConfig, BertTokenizer
import logging
import logging.config
from htmllaundry import strip_markup
from utils import entity_replace_multi_key, text_processed, entity_processed

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

logging.config.fileConfig('./config/logging.conf')
logger = logging.getLogger('std')

app = Flask(__name__)


class NeuralNet(nn.Module):
    """定义网络结构: bert对最后一层所有隐向量使用GAP,MAP
    """

    def __init__(self, pretrained_path, num_labels=2):
        super(NeuralNet, self).__init__()
        self.config = BertConfig.from_pretrained(pretrained_path, num_labels=num_labels)
        self.bert = BertModel.from_pretrained(pretrained_path, config=self.config)
        self.dropout = nn.Dropout(.5)
        self.linear = nn.Linear(3072, num_labels)
        self.gap = torch.nn.AdaptiveAvgPool1d(1)
        self.gmp = torch.nn.AdaptiveMaxPool1d(1)

    def forward(self, input_ids, input_mask, input_seg):
        last_hidden_states, pooled_out = self.bert(input_ids=input_ids,
                                                   attention_mask=input_mask,
                                                   token_type_ids=input_seg)  # [batch,max_length,hidden_size],[batch,hidden_size]

        # print(last_hidden_states.shape)
        embedding = last_hidden_states.permute(0, 2, 1)    # [batch,hidden_size,max_length]
        q = self.gap(embedding).squeeze(dim=-1)    # [batch, hidden_size]
        a = self.gmp(embedding).squeeze(dim=-1)    # [batch, hidden_size]
        t = last_hidden_states[:, -1]              # [batch, hidden_size]
        e = last_hidden_states[:, 0]               # [batch, hidden_size]
        x = torch.cat([q, a, t, e], dim=1)
        # x = self.dropout(x)
        out = self.linear(x)
        out = out.view(-1, self.config.num_labels)
        return out


# model / serving parameters
BERT_MODEL_PATH = './RoBERTa_zh_L12_PyTorch'
MAX_LENGTH = 200
NUM_LABELS = 2
MODEL_WEIGHT = './model_save/roberta-datav12_1.bin'

with torch.no_grad():
    model = NeuralNet(BERT_MODEL_PATH, num_labels=NUM_LABELS)
    model.cuda()
    model.eval()
    model.load_state_dict(torch.load(MODEL_WEIGHT))
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_PATH)


def convert_to_feature(text):
    """Tokenize a single text and zero-pad ids, mask, and segments to MAX_LENGTH."""
    inputs = tokenizer.encode_plus(text, max_length=MAX_LENGTH, truncation=True)
    input_ids, input_mask, input_seg = inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"]
    padding_length = MAX_LENGTH - len(input_ids)
    input_ids += [0] * padding_length
    input_mask += [0] * padding_length
    input_seg += [0] * padding_length


def muiti_ent_predict(title, entities):
    batch_token_ids, batch_segment_ids, batch_atten_mask = [], [], []
    text = text_processed(title)
    entities = [entity_processed(ent) for ent in entities]
    for ent in entities:
        text_p = entity_replace_multi_key(text, ent, entities)
        logger.info("处理后:%s" % text_p)
        # 转化成数值特征
        token_ids, segment_ids, atten_mask = convert_to_feature(text_p)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_atten_mask.append(atten_mask)
    with torch.no_grad():
        batch_ids_tensor = torch.tensor(batch_token_ids, dtype=torch.long)
        batch_seg_tensor = torch.tensor(batch_segment_ids, dtype=torch.long)
        batch_atten_tensor = torch.tensor(batch_atten_mask, dtype=torch.long)
        batch = [batch_ids_tensor, batch_atten_tensor, batch_seg_tensor]
        batch = tuple(t.cuda() for t in batch)
        pred = model(*batch)
        pred = np.argmax(pred.cpu().detach().numpy(), axis=1)
    logger.info(pred)
    pred_ent = []
    for i, j in zip(entities, pred):
        # print(i,j)
        if j == 1:
            pred_ent.append(i)
    # logger.info('ent_pred',pred_ent)
    return pred_ent


@app.route('/upload_and_predict', methods=['GET', 'POST'])
def upload_and_predict():
    result = {
        'code': 0,
        'msg': 'OK',
        'body': [],
        'sysnum': 'null'
    }

    if request.method == 'POST':
        # logging.info('predict begin')
        t1 = time.time()
        try:
            data = request.get_data()
            json_data = json.loads(data.decode('utf-8'))
        except Exception:
            result['msg'] = 'ERROR: "bad request parameters"'
            result['code'] = 10001
            # bail out early: json_data is undefined beyond this point
            return json.dumps(result, ensure_ascii=False)
        logger.info(json_data)
        for i in json_data:
            text = i.get("text")
            entitys = i.get("entitys")
            entitys = entitys.split(',')
            pred_ent = muiti_ent_predict(text, entitys)
            logger.info('text: {} entitys: {}'
                        .format(text, entitys))
            logger.info('pred_entitys: {} time: {:5.2f}s'
                        .format(pred_ent, time.time() - t1))
            result['body'].append(pred_ent)
        # except:
        #     result['msg'] = 'ERROR: “模型预测出错”'
        #     result['code'] = 10001

        return json.dumps(result, ensure_ascii=False)
    else:
        result['msg'] = 'ERROR: "only POST requests are supported"'
        result['code'] = 10001
        return json.dumps(result, ensure_ascii=False)


class MyException(Exception):
    def __init__(self, message):
        super().__init__()
        self.message = message


if __name__ == "__main__":
    app.config['JSON_AS_ASCII'] = False
    app.run(host='0.0.0.0', port=21130, threaded=False)
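
For reference, the route above expects a JSON array in the request body, where each element carries a "text" field and a comma-separated "entitys" string, and it answers with the result dict serialized as JSON. A minimal client sketch (host and port taken from app.run above; the payload values are invented examples):

# sketch of calling the service; the endpoint and field names come from
# upload_and_predict() above, the sample payload itself is made up
import json
import requests

payload = [
    {"text": "A公司因违规被处罚,B公司不受影响", "entitys": "A公司,B公司"}
]
resp = requests.post("http://127.0.0.1:21130/upload_and_predict",
                     data=json.dumps(payload).encode("utf-8"))
print(resp.json())  # e.g. {"code": 0, "msg": "OK", "body": [[...]], "sysnum": "null"}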

utils.py

# -*- coding=utf-8 -*-
import re
import pandas as pd
from htmllaundry import strip_markup


def strQ2B(ustring):
    """Convert full-width characters to half-width."""
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:                    # full-width space maps directly to ASCII space
            inside_code = 32
        elif 65281 <= inside_code <= 65374:         # other full-width chars differ by a fixed offset
            inside_code -= 65248

        rstring += chr(inside_code)
    return rstring

def entity_processed(ent):
    """
    Normalize an entity so it matches the processed text. Known mismatch sources:
    1. English case differences
    2. Entity parentheses are full-width while the text has been normalized to half-width
    3. Part of a company name is recognized as another company: once the full name has
       been replaced, the contained substring can no longer be found, so it is dropped
    :param ent:
    :return:
    """
    ent = ent.lower()
    ent = ent.replace("（", "(")
    ent = ent.replace("）", ")")
    return ent


def text_processed(text):
    """
    Preprocess the text; special markers can optionally be added at the entity's
    position in the text.
    :param text:
    :return:
    """
    text = text.lower()
    text = strip_markup(text)
    text = strQ2B(text)  # full-width to half-width
    # strip HTML escape entities
    text = re.sub(r"&[a-z]+;", "", text)
    text = re.sub(r"&[a-z]+$", "", text)
    text = re.sub(r"\t", "", text)
    text = re.sub(r"\n", "", text)
    # normalize Chinese curly quotes “” to plain double quotes
    text = re.sub(r"“", "\"", text)
    text = re.sub(r"”", "\"", text)
    # <> is reserved for marking entities
    # text = text.replace("<", "")
    # text = text.replace(">", "")
    text = text.replace("@", "")
    # if '_' not in ent:
    #     # an entity containing '_' maps to several keywords, e.g. 龙光地产_3380.hk
    #     cur_title = text.replace(ent, "@" * len(ent))
    # else:
    #     key_words = ent.split('_')
    #     # with several keywords, every one of them needs to be marked
    #     cur_title = text
    #     for k in key_words:
    #         cur_title = cur_title.replace(k, "@" * len(k))

    return text


def entity_replace_one_key(sentence, cur_entity, entity_list):
    """ Replace the current entity in the sentence, handling entities that contain
        one another: with e.g. "中国信安" and "青海中国信安", naively replacing
        "中国信安" would also clobber the longer entity.
        Handles only the one-keyword-per-entity case.
    """
    # only entities longer than the current one can contain it, so sort the list longest first
    entity_list.sort(key=lambda x: len(x), reverse=True)
    index = entity_list.index(cur_entity)  # locate the current entity
    flag = False
    for i in entity_list[:index]:
        # check whether the current entity is contained in a longer entity
        if cur_entity in i:
            flag = True
    if flag:  # contained: needs special handling
        for i, entity in enumerate(entity_list[:index]):  # mask the longer entities first, so the current one is not replaced inside them
            sentence = sentence.replace(entity, "<entity_%d>" % i)
        sentence = sentence.replace(cur_entity, "@" * len(cur_entity))
        for i, entity in enumerate(entity_list[:index]):  # restore the masked entities
            sentence = sentence.replace("<entity_%d>" % i, entity)
    else:
        sentence = sentence.replace(cur_entity, "@" * len(cur_entity))
    return sentence


def entity_replace_multi_key(sentence, cur_entity, entity_list):
    """
    Handle entities that may map to several keywords (joined by '_').
    :param sentence:
    :param cur_entity:
    :param entity_list:
    :return:
    """
    entity_list_split = []  # flatten every entity's keywords into one list
    for i in entity_list:
        entity_list_split.extend(i.split("_"))
    for i in cur_entity.split("_"):
        sentence = entity_replace_one_key(sentence, i, entity_list_split)
    return sentence


def convert_to_single_ent(df):
    """
    Reshape the data to one entity per row, adding a negative/non-negative label.
    :return:
    """
    data = []
    for i in df.index.values:
        text, all_entities, eng_entities = df.iloc[i]
        text = text_processed(text)
        all_entities, eng_entities = [entity_processed(ent) for ent in str(all_entities).split(",")], \
                                     [entity_processed(ent) for ent in str(eng_entities).split(",")]
        for ent in all_entities:     # one row per entity
            text_p = entity_replace_multi_key(text, ent, all_entities)
            if "@" in text_p:
                if ent in eng_entities:  # the entity is in the negative-entity list: label 1
                    data.append((text_p, ent, 1))
                else:                    # otherwise: label 0
                    data.append((text_p, ent, 0))

    return pd.DataFrame(data)
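
A quick, self-contained illustration of the masking behaviour above (the entity names are taken from the docstring example; the point is that the longer containing entity survives while only the standalone mention is masked):

# sketch: entity_replace_one_key protects longer entities that contain the current one
from utils import entity_replace_one_key

sentence = "中国信安与青海中国信安均被点名"
ents = ["青海中国信安", "中国信安"]
print(entity_replace_one_key(sentence, "中国信安", ents))
# -> "@@@@与青海中国信安均被点名": only the standalone mention becomes '@'s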

run_classify.py

#-*- coding:utf-8 -*-
import torch
import pandas as pd
import numpy as np
import logging
import logging.config

from tqdm import tqdm
from torch import nn
from pytorch_transformers.modeling_bert import BertConfig, BertModel
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers import AdamW
from torch.utils.data import (DataLoader, TensorDataset)
from torch.utils.data import SequentialSampler, RandomSampler, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score,classification_report
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
#from radam.radam import RAdam

logging.config.fileConfig('./config/logging.conf')
logger = logging.getLogger('std')

# hyperparameters
max_length = 250
learning_rate = 2e-5
epochs = 15
patience = 5
batch_size = 16
do_train = True
do_eval = True
do_error_analyse = False

train_fpath = './data/processed/train.xlsx'
#train_fpath = './data/processed/训练数据0407.xlsx'
test_fpath = './data/processed/test.xlsx'

#test_fpath = './data/interim/正文负面.xlsx'
bert_model = '/data/sfang/BertPretrainedModel/torch/RoBERTa_zh_L12_PyTorch'
save_model = './model_save/'
file_name = 'roberta-FAQ-model'  # modeled FAQ-style, as a text/entity sentence pair


class InputExample():
    """Data structure: holds the fields of one example."""
    def __init__(self,text,entity,label):
        self.text = text
        self.entity = entity
        self.label = label

class InputFeature():
    """Data structure: the preprocessed numeric features of one example, ready to feed the model."""
    def __init__(self,x_ids,x_seg,atten_mask,label):
        self.features = {
            'x_ids':x_ids,
            'x_seg':x_seg,
            'atten_mask':atten_mask
        }
        self.label = label

class NeuralNet(nn.Module):
    """Network definition"""
    def __init__(self, pretrained_path, num_labels=2):
        super(NeuralNet, self).__init__()
        self.config = BertConfig.from_pretrained(pretrained_path, num_labels=num_labels)
        self.bert = BertModel.from_pretrained(pretrained_path, config=self.config)
        # self.dropouts = nn.ModuleList(
        #     [nn.Dropout(.5) for _ in range(5)]
        # )
        self.dropout = nn.Dropout(.5)
        self.linear = nn.Linear(self.config.hidden_size * 2, self.config.num_labels)

    def forward(self, x_ids, x_seg, x_mask):
        last_hidden_states, pooled_out = self.bert(input_ids=x_ids,token_type_ids=x_seg,
                                                   attention_mask=x_mask)  # [batch,max_length,hidden_size],[batch,hidden_size]
        pooled_output = self.dropout(pooled_out)
        #position_att = outputs
        # mean-pool the hidden states of the entity (segment B, x_seg == 1) tokens
        target = last_hidden_states * x_seg.float().unsqueeze(-1)
        target = target.sum(dim=1)
        target_div = x_seg.sum(dim=1)
        target = target.div(target_div.float().unsqueeze(-1))
        # target = position_att[0][:,0,:]
        target_cls = torch.cat((target, pooled_output), -1)  # [batch, hidden_size*2]
        logits = self.linear(target_cls)
        #out = self.linear(self.dropout(pooled_out))
        return logits

# class NeuralNet(nn.Module):
#     """定义网络结构"""
#     def __init__(self, pretrained_path, num_labels=2):
#         super(NeuralNet, self).__init__()
#         self.config = BertConfig.from_pretrained(pretrained_path, num_labels=num_labels)
#         self.bert = BertModel.from_pretrained(pretrained_path, config=self.config)
#         self.dropouts = nn.ModuleList(
#             [nn.Dropout(.5) for _ in range(5)]
#         )
#         # self.dropout = nn.Dropout(.5)
#         self.linear = nn.Linear(self.config.hidden_size, self.config.num_labels)

#     def forward(self, x_ids, x_seg, x_mask):
#         last_hidden_states, pooled_out = self.bert(input_ids=x_ids,token_type_ids=x_seg,
#                                                    attention_mask=x_mask)  # [batch,max_length,hidden_size],[batch,hidden_size]
#         for i, dropout in enumerate(self.dropouts):
#             if i == 0:
#                 out = self.linear(dropout(pooled_out))
#             else:
#                 out += self.linear(dropout(pooled_out))
#         out = out / len(self.dropouts)
#         out = out.view(-1, self.config.num_labels)
#         #out = self.linear(self.dropout(pooled_out))
#         return out


def read_data(fpath):
    """Read the spreadsheet; columns 1, 2, and 4 hold text, entity, and label."""
    examples = []
    df = pd.read_excel(fpath, header=None)
    df = df.dropna()
    for i in df.values.tolist():
        ## debug print(i[1])
        text,entity,label = i[1],i[2],i[4]
        examples.append(InputExample(text,entity,label))
    return examples,df

def convert_examples_to_features(examples,tokenizer):
    features = []
    for i,example in enumerate(examples):
        tokenA = tokenizer.tokenize(example.text)
        tokenB = tokenizer.tokenize(example.entity)
        max_text_length = max_length - 3 - len(tokenB)  # reserve room for [CLS] and two [SEP]
        tokenA = tokenA[:max_text_length]

        token = ["[CLS]"] + tokenA + ["[SEP]"] + tokenB + ["[SEP]"] #"[CLS]" + A + "[SEP]" + "B" + "[SEP]"
        x_ids = tokenizer.convert_tokens_to_ids(token)
        x_seg = [0] * (len(tokenA) + 2) + [1] * (len(tokenB) + 1)
        atten_mask = [1] * len(token)

        padding_length = max_length - len(token)
        x_ids += [0] * padding_length
        x_seg += [0] * padding_length
        atten_mask += [0] * padding_length

        label = example.label

        if i < 5 :
            logger.info("********** examples  *********")
            logger.info("ids: {}".format(i))
            logger.info("tokens: {}".format(' '.join(token)))
            logger.info("x_ids:{}".format(' '.join(map(str,x_ids))))
            logger.info("x_seg:{}".format(' '.join(map(str,x_seg))))
            logger.info("atten_mask: {}".format(atten_mask))
            logger.info("label: {}".format(label))

        features.append(InputFeature(x_ids,x_seg,atten_mask,label))
    return features

def select_field(features,field):
    """Return the given feature field for every example."""
    return [ele.features[field] for ele in features]

def metric(y_true, y_pred):
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')
    return acc, f1


tokenizer = BertTokenizer.from_pretrained(bert_model)
all_examples,_ = read_data(train_fpath)
all_features = convert_examples_to_features(examples=all_examples,tokenizer=tokenizer)

all_ids = np.array(select_field(all_features,'x_ids'))
all_seg = np.array(select_field(all_features,'x_seg'))
all_atten_mask = np.array(select_field(all_features,'atten_mask'))
all_labels = [ele.label for ele in all_features]

test_examples,_ = read_data(test_fpath)
test_features = convert_examples_to_features(test_examples,tokenizer)
test_ids = torch.tensor(select_field(test_features,'x_ids'), dtype=torch.long)
test_seg = torch.tensor(select_field(test_features,'x_seg'),dtype=torch.long)
test_atten_mask = torch.tensor(select_field(test_features,'atten_mask'),dtype=torch.long)
test_labels = torch.tensor([ele.label for ele in test_features],dtype=torch.long)

# split into train/validation sets, 8:2
train_ids, valid_ids, train_seg, valid_seg, train_masks, valid_masks, train_labels, valid_labels = \
    train_test_split(all_ids, all_seg, all_atten_mask, all_labels, test_size=0.2, random_state=42)
print(np.array(train_ids).shape)
print(np.array(train_masks).shape)
print(np.array(train_labels).shape)

train_ids_tensor = torch.tensor(train_ids, dtype=torch.long)
train_seg_tensor = torch.tensor(train_seg,dtype=torch.long)
train_masks_tensor = torch.tensor(train_masks, dtype=torch.long)
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)

valid_ids_tensor = torch.tensor(valid_ids, dtype=torch.long)
valid_seg_tensor = torch.tensor(valid_seg,dtype=torch.long)
valid_masks_tensor = torch.tensor(valid_masks, dtype=torch.long)
valid_labels_tensor = torch.tensor(valid_labels, dtype=torch.long)

train_datasets = torch.utils.data.TensorDataset(train_ids_tensor, train_seg_tensor, train_masks_tensor, train_labels_tensor)
valid_datasets = torch.utils.data.TensorDataset(valid_ids_tensor, valid_seg_tensor, valid_masks_tensor, valid_labels_tensor)
test_datasets = torch.utils.data.TensorDataset(test_ids, test_seg, test_atten_mask, test_labels)

train_loader = torch.utils.data.DataLoader(train_datasets, shuffle=True, batch_size=batch_size)
valid_loader = torch.utils.data.DataLoader(valid_datasets, shuffle=False, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_datasets, shuffle=False, batch_size=batch_size)

early_stopping = 0
best_f1 = 0.0
if do_train:

    logger.info("*************** Train ******************")
    model = NeuralNet(bert_model)
    model.cuda()

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
    # optim = RAdam(optimizer_grouped_parameters, lr=learning_rate, eps=1e-6)
    optim = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-6)
    #loss_fn = torch.nn.BCEWithLogitsLoss()
    weights = torch.tensor([3.0, 1.0], dtype=torch.float).cuda()  # per-class loss weights
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

    for epoch in range(epochs):
        # run one training epoch; reset to train mode, since eval() is set during validation below
        model.train()
        train_loss = 0
        pbar = tqdm(train_loader)
        for i, batch in enumerate(pbar):
            batch = tuple(t.cuda() for t in batch)
            batch_ids, batch_seg, batch_mask, batch_label = batch
            preb = model(batch_ids, batch_seg, batch_mask)
            loss = loss_fn(preb, batch_label)
            optim.zero_grad()
            loss.backward()
            optim.step()
            train_loss += loss.item() / len(train_loader)
            pbar.set_description("loss%.4f" % loss)
        valid_loss = 0.0

        model.eval()
        valid_preb = np.zeros(shape=(valid_ids_tensor.shape[0], 2))
        with torch.no_grad():
            for i, batch in tqdm(enumerate(valid_loader)):
                batch = tuple(t.cuda() for t in batch)
                ids, segs, masks, labels = batch
                preb = model(ids, segs, masks)
                # labels = labels.view(-1,1)
                # one_hot_label = torch.zeros(batch_size,2).cuda()
                # one_hot_label.scatter_(1,labels,1)
                valid_loss += loss_fn(preb, labels).item() / len(valid_loader)
                valid_preb[batch_size * i:batch_size * (i + 1)] = F.softmax(preb, dim=1).cpu().numpy()

        acc, f1 = metric(valid_labels, np.argmax(valid_preb, axis=1))
        if f1 > best_f1:
            best_f1 = f1
            early_stopping = 0
            torch.save(model.state_dict(), save_model + '%s.bin' % file_name)
        else:
            early_stopping += 1
            if early_stopping >= patience:
                break
        logger.info(
            'epoch: %d, train loss: %.8f, valid loss: %.8f, acc: %.8f, f1: %.8f, best_f1: %.8f\n' %
            (epoch, train_loss, valid_loss, acc, f1, best_f1))
        torch.cuda.empty_cache()  # free cached GPU memory after each epoch to avoid running out

if do_eval:
    logger.info("****************** Evaluate *******************")
    preds = []
    y_label = []
    with torch.no_grad():
        model = NeuralNet(bert_model)
        model.load_state_dict(torch.load(save_model + '%s.bin' % file_name))
        model.cuda()
        model.eval()  # disable dropout for inference
        for i,batch in enumerate(test_loader):
            batch = tuple(t.cuda() for t in batch)
            x_idx, x_seg, atten_mask,label = batch
            y_preb = model(x_idx, x_seg, atten_mask)
            label = label.cpu().numpy()
            y_preb = np.argmax(y_preb.cpu().numpy(),axis=1)
            #print(y_preb)
            preds.extend(y_preb)
            y_label.extend(label)
            # preds[batch_size*i:batch_size*(i+1)] = y_preb
            # labels[batch_size*i:batch_size*(i+1)] = label
    logger.info(classification_report(y_label, preds))
    if do_error_analyse:
        error_samples = []
        for i in range(len(test_examples)):
            if y_label[i] != preds[i]:
                error_samples.append((test_examples[i].text,y_label[i],preds[i]))
        pd.DataFrame(error_samples).to_excel("./output/error_samples.xlsx",index=None,header=['text','label','pred'])
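
Since read_data loads the spreadsheets with header=None and unpacks columns 1, 2, and 4, a compatible train.xlsx can be produced as in the sketch below (the column layout is inferred from read_data; the row contents and the spare columns 0 and 3 are placeholders):

# sketch: write a train.xlsx matching text, entity, label = i[1], i[2], i[4]
import pandas as pd

rows = [
    # col0: id, col1: masked text, col2: entity, col3: unused, col4: label
    [0, "@@@因违规被处罚", "A公司", "", 1],
    [1, "@@@发布年度报告", "A公司", "", 0],
]
pd.DataFrame(rows).to_excel("./data/processed/train.xlsx", index=False, header=False)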
