IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> tensorflow2的utils.Sequence -> 正文阅读

[Python知识库]tensorflow2的utils.Sequence

假设文件是这样的

images和labels里面保存的都是.npy数组

images里面的一个数据的shape=[128,128,16,1],labels里面的一个数据的shape=[128,128,16,2],因为是二分类语义分割

data_loader.py

from tensorflow.keras.utils import Sequence
import numpy as np
import math


# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.


class seg3D_Sequence(Sequence):
    def __init__(self,file_name_list, batch_size,image_path="data/images/",
                                                label_path="data/labels/"):
        self.file_name_list = file_name_list
        self.batch_size = batch_size
        self.image_path = image_path
        self.label_path = label_path        


    def __len__(self):
        return math.ceil(len(self.file_name_list) / self.batch_size)

    def __getitem__(self, idx):
        self.x, self.y = [self.image_path+file_name for file_name in self.file_name_list], \
            [self.label_path+file_name for file_name in self.file_name_list]

        # print(self.x)
        # print(('-'*60).center(60))
        # print(self.y)

        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]

        x_re = [np.load(file_name) for file_name in batch_x]
        y_re = [np.load(file_name) for file_name in batch_y]

        return np.array(x_re),np.array(y_re)
    
    def on_epoch_end(self):
        np.random.shuffle(self.file_name_list)

?train.py

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from utils.data_loader import seg3D_Sequence
from tensorflow.keras import Model,Sequential
from tensorflow.keras.layers import Conv3D,Input
import numpy as np


if __name__ == "__main__":
    print('-'*60)
    num_classes = 2
    batch_size = 7
    train_val_split = 0.2
    image_path = "data/images/"
    
    file_name_list = os.listdir(image_path)
    train_name_list = file_name_list[:int(len(file_name_list)*0.8)]
    val_name_list = file_name_list[int(len(file_name_list)*0.8):]
    # print(len(file_name_list)) # 360

    train_data_loader = seg3D_Sequence(train_name_list,batch_size)
    val_data_loader = seg3D_Sequence(val_name_list,batch_size)
    # x,y = data_loader[90]
    # print(x.shape,y.shape)

    model = Sequential()
    model.add(Conv3D(num_classes,1,activation='sigmoid'))
    inputs = Input(shape=[128,128,16,1])
    outputs = model(inputs)
    print(outputs.shape)

    model = Model(inputs,outputs,name='test3d')
    model.summary()

    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy']
                  )

    model.fit(train_data_loader,
                epochs=3,
                validation_data=val_data_loader)


  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-03-17 22:06:24  更:2022-03-17 22:07:23 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 13:33:52-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计