🐱 CNN-RNN-Based Medical Text Generation

This project uses a ResNet-101 network pretrained on ImageNet to extract image features, which are then fed into an LSTM to generate a textual description of the radiograph.

It implements a simple image-to-text generation pipeline.


📖 0 Project Background

With the rapid progress of deep learning in recent years, the technology has shown great potential in healthcare. If deep learning methods could let a computer take over the mechanical work of drafting imaging reports, this would both reduce misdiagnoses made by less experienced radiologists during image reading and free senior radiologists from repetitive work, letting them devote more time to patient care.

Automatic generation of medical imaging reports is an emerging intersection of computer science and medical imaging. Current report-generation models largely borrow the Encoder-Decoder framework from machine translation: a Convolutional Neural Network (CNN) extracts the image features, and a Recurrent Neural Network (RNN) then generates the textual description of the image.
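To make that wiring concrete, here is a minimal sketch of the Encoder-Decoder idea in Paddle. It is illustrative only: the sizes (`visual_dim`, `embed_dim`, etc.) are placeholders, and the actual model used by this project is defined in Section 4.

# Minimal Encoder-Decoder sketch (illustrative only; see Section 4 for the real model)
import paddle

visual_dim, embed_dim, hidden, vocab = 2048, 512, 512, 1151  # placeholder sizes

fc = paddle.nn.Linear(visual_dim, embed_dim)   # project the CNN feature into embedding space
embed = paddle.nn.Embedding(vocab, embed_dim)  # word embeddings for the report tokens
lstm = paddle.nn.LSTM(embed_dim, hidden)       # decoder RNN
head = paddle.nn.Linear(hidden, vocab)         # per-step vocabulary logits

img_feat = paddle.randn([1, visual_dim])       # stands in for a ResNet-101 feature
tokens = paddle.randint(0, vocab, [1, 10])     # stands in for a tokenized FINDINGS prefix

# The projected image feature is prepended as the first "word" of the sequence
inputs = paddle.concat([fc(img_feat).unsqueeze(1), embed(tokens)], axis=1)
outputs, _ = lstm(inputs)
print(head(outputs).shape)  # [1, 11, vocab]: one next-word distribution per step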

📌 1 Dataset

The Indiana University chest X-ray collection (IU X-Ray) is a set of chest radiographs paired with diagnostic reports. The dataset contains 7,470 image-report pairs (split 6470:500:500). Each report consists of the following sections: impression, findings, tags, comparison, and indication. On average, each image is associated with 2.2 tags and 5.7 sentences, and each sentence contains 6.5 words.

This project uses only the FINDINGS section as the generation target for each image.

Reference code:

  • https://github.com/chenyuntc/pytorch-book/tree/master/chapter10-image_caption

🐰 2 Dataset Generation

🐅 2.1 Generating the Medical-Text CSV


Unzip the raw data, parse the XML files, extract each image's filename and the corresponding FINDINGS text, and write the result to a CSV file.

# Unzip the dataset
!unzip  -o  data/data123482/IU数据集.zip -d /home/aistudio/work/
# Suppress warning output
import warnings
warnings.filterwarnings("ignore")
## Build the CSV dataset
# Average caption length: 31.64992700729927 words
import os
import glob
import pandas as pd
from xml.dom import minidom
import re
import numpy as np
LENGTH = []

def EmptyDrop(data):
    for i in range(len(data)):
        if data.loc[i,'dir'] ==[] or data.loc[i,'caption'] ==[]:
            # Drop the row if either field is empty
            data.drop([i],axis = 0,inplace = True) 
        else:
            data.loc[i,'dir'] = data.loc[i,'dir'][0]
            data.loc[i,'caption'] = data.loc[i,'caption'][0]
    data.reset_index(drop = True,inplace = True)
    return data

def clean_text(origin_text):
    # Strip punctuation and non-letter characters
    text = re.sub("[^a-zA-Z]"," ",origin_text)
    # Lowercase everything
    cleaned_text = text.lower()
    return cleaned_text
def xml2csv(path):
    num = 0
    column_name = ['dir','caption']
    xml_csv = pd.DataFrame(columns = column_name)
    # Directory containing the PNG images
    pic_path =  'work/IU数据集/NLMCXR_png'
    for xml_file in glob.glob(path+'/*.xml'):

        # Everything to keep from this xml: figure id + path, FINDINGS (and optionally IMPRESSION)
        xml_list = []
        # Parse the xml document
        dom = minidom.parse(xml_file)
        # Get the document root element
        root = dom.documentElement
        # Figure paths
        itemlists=root.getElementsByTagName('parentImage')
        # Collect the image paths
        dirAll = []
        for itemlist in itemlists:
            figureId=itemlist.getElementsByTagName('figureId')
            # The figureId of this image
            figure = figureId[0].childNodes[0].nodeValue
            # The name (id) of this image
            ID= itemlist.getAttribute('id')
            IdPath = ID
            # Frontal & lateral views
            figurePath = [figure+' '+IdPath]
            dirAll.extend(figurePath)
        xml_list.append(dirAll)

        # Collect the report text (FINDINGS; IMPRESSION could be enabled as well)
        CaptionAll = []
        itemlists=root.getElementsByTagName('AbstractText')
        for i in range(len(itemlists)):
            Label= itemlists[i].getAttribute('Label')
            if Label == 'FINDINGS': # or Label == 'IMPRESSION':
                # Skip empty nodes
                if len(itemlists[i].childNodes)!=0:
                    text = itemlists[i].childNodes[0].nodeValue
                    # Lowercase and strip invalid characters
                    text = clean_text(text)
                    text = text.replace('.','')
                    text = text.replace(',','')
                    text = [text+'']
                    CaptionAll.extend(text)
        if len(CaptionAll)>=1:
            LENGTH.append(len(CaptionAll[0].split(' ')))
        xml_list.append(CaptionAll)
        xml_csv.loc[num] = [item for item in xml_list]
        num = num + 1
        print('file [{}/{}]'.format(num,len(glob.glob(path+'/*.xml'))))
    # print(np.mean(LENGTH))
    return xml_csv

def main():
    xml_path = os.path.join('work','IU数据集','NLMCXR_reports','ecgen-radiology')
    csv = xml2csv(xml_path)
    csv1 = EmptyDrop(csv)
    csv1.to_csv('work/IUxRay.csv',index=None)
    
if __name__ == '__main__':
    main()

🐅 2.2 Image Feature Extraction


  • Extract image features with a ResNet-101 model pretrained on ImageNet (the final fully connected layer is removed and replaced with an identity mapping).

  • Save the features to an h5 file.

## Extract image features with the pretrained ResNet-101
import paddle
from paddle.vision.models import resnet101
import h5py
import cv2

csv_file = pd.read_csv('work/IUxRay.csv')
h5_png_file =list(csv_file['dir'])

# Create the output directory
save_path = 'work/util_IUxRay'
if not os.path.exists(save_path):
    os.makedirs(save_path)

# Load ResNet-101 pretrained on ImageNet
model = resnet101(pretrained=True)
# Replace the final fully connected layer with an identity mapping
del model.fc
model.fc = lambda x:x
# Switch to eval mode so BatchNorm uses its running statistics during feature extraction
model.eval()

h5f = h5py.File(os.path.join(save_path,'resnet101_festures.h5'), 'w')

for idx,item in enumerate(h5_png_file):
    # Iterate over all PNGs for this report (F1, F2, ...)
    print(idx,len(h5_png_file))
    item_all = item.split(',')
    for item_t in item_all:
        item_t = item_t.replace('\'','').replace('[','').replace(']','')
        # Distinguish images by view/orientation tag
        for orie in ['F1','F2','F3','F4']:
            if orie in item_t:
                orie_fin = orie
                item_fin = item_t.replace(orie,'').replace(' ','')
                item_fin_png = item_fin + '.png'
                print(orie_fin + '_' + item_fin)
                # Read the image, run it through the model, and store the features
                img = cv2.imread(os.path.join('work/IU数据集/NLMCXR_png',item_fin_png))
                # BGR -> RGB, and HWC -> CHW
                img = img[:,:,::-1].transpose((2,0,1))
                # Add a batch dimension
                img = np.expand_dims(img, 0)
                img_tensor = paddle.to_tensor(img,dtype='float32' ,place=paddle.CPUPlace())
                # Extract the features
                out = model(img_tensor)
                data = out.numpy().astype('float32')
                # Store the 2048-d feature vector in the h5 file
                save_path_h5 = data[0]
                h5f.create_dataset(orie_fin + '_' + item_fin, data=save_path_h5)
h5f.close()

# Read back the h5 file
import h5py
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')
data = h5f['F1_CXR3027_IM-1402-1001'] # Before the first underscore: image view; after it: the original image name
print(np.array(data).shape) # Each image is stored as a 2048-d vector
h5f.close()

🐅 2.3 Vocabulary Construction


  • Scan the training data and build a word-level vocabulary.

  • Prune the vocabulary: drop words that occur only once in the dataset.

# Scan the training data and build a word-level vocabulary
import pandas as pd
import numpy as np
import re

csv_file = pd.read_csv('work/IUxRay.csv')
csv_file.head()

CaptionWordAll = []
CaptionWordLength = []
for idx,data_ in enumerate(csv_file.iterrows()):
    caption = data_[1][1]
    CaptionWordLength.append(len(caption.split(' ')))
    CaptionWordAll.extend(caption.split(' '))

print('Mean sentence length:',np.mean(CaptionWordLength))
print('Max sentence length:',np.max(CaptionWordLength))
print('Min sentence length:',np.min(CaptionWordLength))

print('Total word count:',len(CaptionWordAll))
print('Vocabulary size:',len(set(CaptionWordAll)))


from collections import Counter
# Count word frequencies and sort from high to low, so the most frequent words sit at the front of the vocabulary and look-ups are fast
counts = Counter(CaptionWordAll)
count_sorted = counts.most_common()
count_sorted_ = {k: v for k, v in count_sorted if v > 1}

# Build the vocabulary
# Reserve four special symbols: <pad> 0, <unk> 1, <start> 2, <end> 3

word2id_dict={'<pad>':0,'<unk>':1,'<start>':2,'<end>':3}
id2word_dict={0:'<pad>',1:'<unk>',2:'<start>',3:'<end>'}

for idx,item in enumerate(count_sorted_):
    idx_ = idx+4 # offset by 4 for the reserved special symbols
    item_ = item
    word2id_dict[item_] = idx_
    id2word_dict[idx_] = item_

# Words that occur only once were dropped above
print('Vocabulary size after pruning:',len(word2id_dict))
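A quick round-trip through the vocabulary (a minimal check; the sample token list below is made up, and any word pruned from the dictionary falls back to `<unk>`):

# Encode a hypothetical token list and decode it back, with <unk> fallback
sample = ['the', 'lungs', 'are', 'clear', 'pneumothoraxxx']  # last token is deliberately unknown
ids = [word2id_dict.get(w, word2id_dict['<unk>']) for w in sample]
decoded = [id2word_dict[i] for i in ids]
print(ids)
print(decoded)  # out-of-vocabulary words come back as '<unk>'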

🥝 3 Defining the Data Reader


  • Split the data 8:2 into training and validation sets.

  • Map the text through the vocabulary. Unlike a translation task, the image feature takes the place of the <start> token's embedding (see stackInput below).

## The DataLoader

import paddle
from paddle.io import Dataset
import numpy as np
from sklearn.model_selection import train_test_split

# Custom dataset class
class CaptionDataset(Dataset):
    # Constructor
    def __init__(self,csvData,word2id_dict,h5f,maxlength = 40,mode = 'train'):
        self.mode = mode
        self.w2i_dict = word2id_dict
        self.maxlength = maxlength # maximum input length in tokens
        self.padid = 0 # 0 is the padding id
        self.h5f = h5f
        # Split the data by mode (train/test)
        train,test =csvData.iloc[:int(0.8*len(csvData)),:],csvData.iloc[int(0.8*len(csvData)):,:] #train_test_split(csvData,train_size=0.8,random_state=10)

        if self.mode == 'train':
            self.data = train.reset_index(drop=True)
        else:
            self.data = test.reset_index(drop=True)

    # __getitem__: fetch one sample (image feature, target ids, target length, image name)
    def __getitem__(self, index):

        path_name, trg_  = self.data.iloc[index,:]
        
        # Load the image feature
        temp  = path_name.split(' ')
        names = '_'.join(temp)
        img_feature = np.array(self.h5f[names]) # Before the first underscore: image view; after it: the original image name

        # Convert the target text to ids
        trg,trg_length = self.generIdxList(trg_)
        img_name = temp[-1]
        return img_feature,trg,trg_length,img_name

    # __len__: total number of samples
    def __len__(self):
        return len(self.data)
    
    def generIdxList(self,tdata):
        # Convert the input string into a list of token ids
        data = tdata.split(' ')
        data_out = []
        # Enforce the maximum length and add '<start>' and '<end>'
        data = ['<start>'] + data

        if len(data)>self.maxlength-1:
            data = data[:self.maxlength-1] # leave one slot for '<end>'
            data = data + ['<end>']

        else:
            # pad up to the fixed length
            occupy_ = ['<pad>'] * (self.maxlength - 1 - len(data))
            data = data + ['<end>']
            data = data + occupy_

        # word -> index, falling back to '<unk>' for out-of-vocabulary words
        for word in data:

            if self.w2i_dict.get(word)!= None: # the word is in the vocabulary
                id_ = self.w2i_dict[word]
                data_out.append(id_)
            else:
                id_ = self.w2i_dict['<unk>']
                data_out.append(id_)

        length = len(data_out)-1

        return data_out,length

def stackInput(inputs):
    # Collate function: stack the batch and build teacher-forcing inputs/labels
    img_features = np.stack([inputsub[0] for inputsub in inputs], axis=0)
    trg = np.stack([inputsub[1] for inputsub in inputs], axis=0)
    trg_length = np.stack([inputsub[2] for inputsub in inputs], axis=0)

    trg_mask =(trg[:,:-1]!=0).astype(paddle.get_default_dtype())
    trg_ = trg[:,1:] # drop the <start> id: the image feature plays the role of <start>
    return img_features,trg_length,trg_[:,:-1],trg[:,1:,np.newaxis],trg_mask
# Smoke-test the data reader
import pandas as pd
import numpy as np
import h5py
from sklearn.model_selection import train_test_split
csvData = pd.read_csv('work/IUxRay.csv')
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')

maxlength = 40

dataset=CaptionDataset(csvData,word2id_dict,h5f,maxlength,'train')
data_loader = paddle.io.DataLoader(dataset, batch_size=1,collate_fn = stackInput, shuffle=False)

for item in data_loader:
    print(item[0].shape,item[1].shape,item[2].shape,item[3].shape,item[4].shape)
    break

💡 4 Defining the Model


  • Define an LSTM model for text generation.

  • Define a beam search algorithm to improve the generated results.

# Define the model
import paddle.nn as nn
import paddle
class CaptionModel(paddle.nn.Layer):

    def __init__(self, vocab_size,embedding_dim,hidden_size,num_layers,word2id_dict,id2word_dict):
        
        super(CaptionModel,self).__init__()

        self.hidden_size=hidden_size
        self.num_layers=num_layers        
        self.fc = paddle.nn.Linear(2048,embedding_dim)
        self.embedding=paddle.nn.Embedding(vocab_size,embedding_dim)
        self.rnn=paddle.nn.LSTM(input_size=embedding_dim,
                                hidden_size=hidden_size,
                                num_layers=num_layers)
        self.word2ix = word2id_dict
        self.ix2word = id2word_dict
        self.classifier = paddle.nn.Linear(hidden_size,vocab_size)

    def forward(self,img_features,trg,trg_length):
        img_features = paddle.unsqueeze(self.fc(img_features),axis = 1)
        embeddings = self.embedding(trg)
        inputs = paddle.concat([img_features,embeddings],axis = 1)
        outputs,state = self.rnn(inputs,sequence_length  = trg_length)
        predict = self.classifier(outputs)

        return predict

    def generate(self, img_feat, eos_token='<end>',
                 beam_size=2,
                 max_caption_length=40,
                 length_normalization_factor=0.0):
        """
        根据图片生成描述,主要是使用beam search算法以得到更好的描述
        """
        cap_gen = CaptionGenerator(embedder=self.embedding,
                                   rnn=self.rnn,
                                   classifier=self.classifier,
                                   eos_id=self.word2ix[eos_token],
                                   beam_size=beam_size,
                                   max_caption_length=max_caption_length,
                                   length_normalization_factor=length_normalization_factor)

        img_feat = paddle.unsqueeze(img_feat,axis = 0)
        img = paddle.unsqueeze(self.fc(img_feat),axis = 0)

        sentences, score = cap_gen.beam_search(img)

        sentences = [' '.join([self.ix2word[int(idx)] for idx in sent])
                     for sent in sentences]
        return sentences
# Beam Search
import paddle.nn as nn
import heapq

class TopN(object):
    """Maintains the top n elements of an incrementally provided set."""

    def __init__(self, n):
        self._n = n
        self._data = []

    def size(self):
        assert self._data is not None
        return len(self._data)

    def push(self, x):
        """Pushes a new element."""
        assert self._data is not None
        if len(self._data) < self._n:
            heapq.heappush(self._data, x)
        else:
            heapq.heappushpop(self._data, x)

    def extract(self, sort=False):
        """
        Extracts all elements from the TopN. This is a destructive operation.
        The only method that can be called immediately after extract() is reset().
        Args:
          sort: Whether to return the elements in descending sorted order.
        Returns:
          A list of data; the top n elements provided to the set.
        """
        assert self._data is not None
        data = self._data
        self._data = None
        if sort:
            data.sort(reverse=True)
        return data

    def reset(self):
        """Returns the TopN to an empty state."""
        self._data = []

class Caption(object):
    """Represents a complete or partial caption."""

    def __init__(self, sentence, state, logprob, score, metadata=None):
        """Initializes the Caption.
        Args:
          sentence: List of word ids in the caption.
          state: Model state after generating the previous word.
          logprob: Log-probability of the caption.
          score: Score of the caption.
          metadata: Optional metadata associated with the partial sentence. If not
            None, a list of strings with the same length as 'sentence'.
        """
        self.sentence = sentence
        self.state = state
        self.logprob = logprob
        self.score = score
        self.metadata = metadata

    def __cmp__(self, other):
        """Compares Captions by score."""
        assert isinstance(other, Caption)
        if self.score == other.score:
            return 0
        elif self.score < other.score:
            return -1
        else:
            return 1

    # For Python 3 compatibility (__cmp__ is deprecated).
    def __lt__(self, other):
        assert isinstance(other, Caption)
        return self.score < other.score

    # Also for Python 3 compatibility.
    def __eq__(self, other):
        assert isinstance(other, Caption)
        return self.score == other.score

class CaptionGenerator(object):
    """Class to generate captions from an image-to-text model."""

    def __init__(self,
                 embedder,
                 rnn,
                 classifier,
                 eos_id,
                 beam_size=3,
                 max_caption_length=100,
                 length_normalization_factor=0.0):
        """Initializes the generator.
        Args:
          model: recurrent model, with inputs: (input, state) and outputs len(vocab) values
          beam_size: Beam size to use when generating captions.
          max_caption_length: The maximum caption length before stopping the search.
          length_normalization_factor: If != 0, a number x such that captions are
            scored by logprob/length^x, rather than logprob. This changes the
            relative scores of captions depending on their lengths. For example, if
            x > 0 then longer captions will be favored.
        """
        self.embedder = embedder
        self.rnn = rnn
        self.classifier = classifier
        self.eos_id = eos_id
        self.beam_size = beam_size
        self.max_caption_length = max_caption_length
        self.length_normalization_factor = length_normalization_factor

    def beam_search(self, rnn_input, initial_state=None):
        """Runs beam search caption generation on a single image.
        Args:
          initial_state: An initial state for the recurrent model
        Returns:
          A list of Caption sorted by descending score.
        """

        def get_topk_words(embeddings, state):
            output, new_states = self.rnn(embeddings, state)
            output = self.classifier(paddle.squeeze(output,axis=0))
            logprobs = nn.functional.log_softmax(output, axis=-1)

            if len(logprobs.shape) == 3:
                logprobs = paddle.squeeze(logprobs)
            logprobs, words = logprobs.topk(self.beam_size, 1)
            return words, logprobs, new_states

        partial_captions  = TopN(self.beam_size)
        complete_captions = TopN(self.beam_size)

        words, logprobs, new_state = get_topk_words(rnn_input, initial_state)

        for k in range(self.beam_size):
            cap = Caption(
                sentence=[words[0, k]],
                state=new_state,
                logprob=logprobs[0, k],
                score=logprobs[0, k])

            partial_captions.push(cap)

        # Run beam search.

        for _ in range(self.max_caption_length - 1):
            partial_captions_list = partial_captions.extract()
            partial_captions.reset()
            input_feed =[c.sentence[-1] for c in partial_captions_list]

            input_feed = paddle.to_tensor(input_feed)
            state_feed = [c.state for c in partial_captions_list]
            if isinstance(state_feed[0], tuple):
                state_feed_h, state_feed_c = zip(*state_feed)
                state_feed = (paddle.concat(state_feed_h, 1),
                              paddle.concat(state_feed_c, 1))
            else:
                state_feed = paddle.concat(state_feed, 1)
            

            embeddings = self.embedder(input_feed)

            words, logprobs, new_states = get_topk_words( 
                embeddings, state_feed)
                
            for i, partial_caption in enumerate(partial_captions_list):
                if isinstance(new_states, tuple):
                    state = (paddle.slice(new_states[0],axes=[1],starts=[i],ends = [i+1]),
                             paddle.slice(new_states[1],axes=[1],starts=[i],ends = [i+1]))
                else:
                    state = new_states[i]
                for k in range(self.beam_size):
                    w = words[i, k]
                    sentence = partial_caption.sentence + [w]
                    logprob = partial_caption.logprob + logprobs[i, k]
                    score = logprob
                    if w == self.eos_id:
                        if self.length_normalization_factor > 0:
                            score /= len(sentence)**self.length_normalization_factor
                        beam = Caption(sentence, state, logprob, score)
                        complete_captions.push(beam)
                    else:
                        beam = Caption(sentence, state, logprob, score)
                        partial_captions.push(beam)
            if partial_captions.size() == 0:
                # We have run out of partial candidates; happens when beam_size
                # = 1.
                break

        # If we have no complete captions then fall back to the partial captions.
        # But never output a mixture of complete and partial captions because a
        # partial caption could have a higher score than all the complete
        # captions.
        if not complete_captions.size():
            complete_captions = partial_captions

        caps = complete_captions.extract(sort=True)

        return [c.sentence for c in caps], [c.score for c in caps]
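As a quick illustration of how the `TopN` helper above behaves (a minimal check with plain floats; in beam search it holds `Caption` objects ordered by score):

# TopN keeps only the n largest items pushed into it
top = TopN(2)
for s in [0.1, 0.5, 0.3]:
    top.push(s)
print(top.extract(sort=True))  # [0.5, 0.3] -- the smallest item was evicted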

🥝 5 Defining the Loss Function


  • Use a standard cross-entropy loss.

  • Apply the trg_mask defined earlier so that no loss is computed on padding positions.

# Define the loss function
class CrossEntropy(paddle.nn.Layer):
    def __init__(self):
        super(CrossEntropy,self).__init__()

    def forward(self,pre,real,trg_mask):

        cost=paddle.nn.functional.softmax_with_cross_entropy(logits=pre,label=real)
        
        # Squeeze out the size-1 dimension at axis=2
        cost=paddle.squeeze(cost,axis=[2])
        
        # trg_mask has shape [batch_size, sequence_len]
        masked_cost=cost*trg_mask
        
        return paddle.mean(paddle.mean(masked_cost,axis=[0]))
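A minimal shape check for this loss, using random tensors only. The sizes mirror a batch from `stackInput` with maxlength = 40: logits `[batch, seq, vocab]`, labels `[batch, seq, 1]`, mask `[batch, seq]`:

# Sanity-check the loss with random inputs shaped like a real batch
import paddle
logits = paddle.randn([2, 39, 1151])          # [batch, seq, vocab]
labels = paddle.randint(0, 1151, [2, 39, 1])  # [batch, seq, 1], int64 class ids
mask = paddle.ones([2, 39], dtype='float32')  # no padding in this fake batch
print(float(CrossEntropy()(logits, labels, mask)))  # ~ln(1151) ≈ 7.05 for random logits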

🦃 6 Setting Parameters and Training


  • Add perplexity as an evaluation metric (a quick consistency check against the loss appears after the evaluation loop below).

  • Set the training hyperparameters.

# Hyperparameters
import h5py

epochs=60
word_size = 1151
eos_id=word2id_dict['<end>']
num_layers=32
hidden_size=512
embedding_dim=512
lr=1e-3
maxlength=40
model_path='./output'

csvData = pd.read_csv('work/IUxRay.csv')
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')

import paddlenlp

model=CaptionModel(word_size,embedding_dim,hidden_size,num_layers,word2id_dict,id2word_dict)

train_dataset=CaptionDataset(csvData,word2id_dict,h5f,maxlength,'train')
train_loader = paddle.io.DataLoader(train_dataset, batch_size=128,collate_fn = stackInput, shuffle=True)

val_dataset=CaptionDataset(csvData,word2id_dict,h5f,maxlength,'test')
val_loader = paddle.io.DataLoader(val_dataset, batch_size=64,collate_fn = stackInput, shuffle=True)

# Optimizer
optimizer=paddle.optimizer.Adam(learning_rate=lr,parameters=model.parameters())
# Loss function
loss_fn = CrossEntropy()
# Perplexity metric
perplexity = paddlenlp.metrics.Perplexity()

model.train()

for epoch in range(epochs):
    for batch_id, data in enumerate(train_loader()):

        img_features,trg_length,inputs,label,label_mask = data[0],data[1],data[2],data[3], data[4]  # batch data

        predicts = model(img_features,inputs,trg_length)    # forward pass

        # Compute the loss (the equivalent of the loss setting in prepare())
        loss = loss_fn(predicts, label , label_mask)

        # Compute the perplexity (the equivalent of the metrics setting in prepare())
        correct = perplexity.compute(predicts, label)
        perplexity.update(correct.numpy())
        ppl = perplexity.accumulate()
        
        # The steps below (backward pass, logging, parameter update, gradient reset)
        # are what Model.fit() would otherwise encapsulate

        # Backward pass
        loss.backward()

        if (batch_id+1) % 20 == 0:
            print("epoch: {}, batch_id: {}, loss is: {}, ppl is: {}".format(epoch+1, batch_id+1, loss.item(), ppl))

            # Save the model parameters to LSTM_model.pdparams
            paddle.save(model.state_dict(), 'work/LSTM_model.pdparams')

        # Update parameters
        optimizer.step()

        # Reset gradients
        optimizer.clear_grad()

model.eval()

for batch_id, data in enumerate(val_loader()):

    img_features,trg_length,inputs,label,label_mask = data[0],data[1],data[2],data[3], data[4]  # batch data

    predicts = model(img_features,inputs,trg_length)    # forward pass

    # Compute the loss
    loss = loss_fn(predicts , label , label_mask)

    # Compute the perplexity
    correct = perplexity.compute(predicts, label)
    perplexity.update(correct.numpy())
    ppl = perplexity.accumulate()

    if (batch_id+1) % 1 == 0:
        print(" batch_id: {}, loss is: {}, ppl is: {}".format( batch_id+1, loss.item(), ppl))

🍓 7 Model Inference


# Evaluate on the validation split
from IPython.display import display
from PIL import Image
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu

path = 'work/IU数据集/NLMCXR_png/'
csvData = pd.read_csv('work/IUxRay.csv')
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')
data = csvData.iloc[int(0.8*len(csvData)):,:]
scores = []

Beam_Size = 3
for idx,data_ in tqdm(enumerate(data.iterrows())):
    F_name = data_[1][0]
    F_text = data_[1][1]

    img_name = F_name.split(' ')[-1]
    h5f_name = '_'.join(F_name.split(' '))
    img_feature = np.array(h5f[h5f_name])
    img_path = path + img_name + '.png'
    
    img_feature = paddle.to_tensor(img_feature)
    results = model.generate(img_feature,beam_size=Beam_Size)

    #print('Prediction:',results[Beam_Size-1])
    #print('Ground truth:',F_text)

    #img = Image.open(img_path).convert('RGB')
    #display(img)
    
    # Compute BLEU
    reference = [F_text.split(' ')]
    candidate = results[Beam_Size-1].split(' ')
    score = sentence_bleu(reference,candidate)
    scores.append(score)

print('Prediction:',results[Beam_Size-1])
print('Ground truth:',F_text)
print('BLEU:',np.mean(scores))
img = Image.open(img_path).convert('RGB')
display(img)
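One practical note on the BLEU computation above: for short generated sentences, nltk's `sentence_bleu` can return 0 (no 4-gram overlap) and emit warnings. A smoothing function avoids this; a small variant using nltk's built-in `SmoothingFunction`:

# BLEU with smoothing, to avoid zero scores on short candidates
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
smooth = SmoothingFunction().method1
score = sentence_bleu(reference, candidate, smoothing_function=smooth)
print('smoothed BLEU:', score)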

🎖️ 8 Project Summary

  • This project demonstrates CNN+RNN based generation of chest X-ray reports.

  • A small bug in the beam search code previously meant that decoding fell back to simply picking the most probable word; this has been fixed, and the Beam Size parameter can now be passed in normally.

  • The project is a simple implementation of the image captioning task in the medical text domain.

  • All code and data are presented as a notebook, which keeps things easy to follow.

  • BLEU is used to evaluate the results.


Special note: this project was inspired by Chapter 10 of 《深度学习框架Pytorch入门与实践》 (a Chinese introduction to the PyTorch deep learning framework).


If you have any questions, feel free to leave a comment.