IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Unet_text -> 正文阅读

[人工智能]Unet_text

def Unet(input_shape=(256,256,3), num_classes=21):
? ? inputs = Input(input_shape)
? ? feat1, feat2, feat3, feat4, feat5 = VGG16(inputs)?
? ? channels = [64, 128, 256, 512]

? ? # 32, 32, 512 -> 64, 64, 512
? ? P5_up = UpSampling2D(size=(2, 2))(feat5)
? ? # 64, 64, 512 + 64, 64, 512 -> 64, 64, 1024
? ? P4 = Concatenate(axis=3)([feat4, P5_up])
? ? # 64, 64, 1024 -> 64, 64, 512
? ? P4 = Conv2D(channels[3], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P4)
? ? P4 = Conv2D(channels[3], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P4)

? ? # 64, 64, 512 -> 128, 128, 512
? ? P4_up = UpSampling2D(size=(2, 2))(P4)
? ? # 128, 128, 256 + 128, 128, 512 -> 128, 128, 768
? ? P3 = Concatenate(axis=3)([feat3, P4_up])
? ? # 128, 128, 768 -> 128, 128, 256
? ? P3 = Conv2D(channels[2], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P3)
? ? P3 = Conv2D(channels[2], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P3)

? ? # 128, 128, 256 -> 256, 256, 256
? ? P3_up = UpSampling2D(size=(2, 2))(P3)
? ? # 256, 256, 256 + 256, 256, 128 -> 256, 256, 384
? ? P2 = Concatenate(axis=3)([feat2, P3_up])
? ? # 256, 256, 384 -> 256, 256, 128
? ? P2 = Conv2D(channels[1], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P2)
? ? P2 = Conv2D(channels[1], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P2)

? ? # 256, 256, 128 -> 512, 512, 128
? ? P2_up = UpSampling2D(size=(2, 2))(P2)
? ? # 512, 512, 128 + 512, 512, 64 -> 512, 512, 192
? ? P1 = Concatenate(axis=3)([feat1, P2_up])
? ? # 512, 512, 192 -> 512, 512, 64
? ? P1 = Conv2D(channels[0], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P1)
? ? P1 = Conv2D(channels[0], 3, activation='relu', padding='same', kernel_initializer = RandomNormal(stddev=0.02))(P1)

? ? # 512, 512, 64 -> 512, 512, num_classes
? ? P1 = Conv2D(num_classes, 1, activation="softmax")(P1)

? ? model = Model(inputs=inputs, outputs=P1)
? ? return model

# step0:参数配置
# dataset_path = r"G:\deep_learning_data\EG_dataset\voc_format" ?# local?
dataset_path = os.path.join(BASE_DIR, "..", "data", "dataset", "voc_format") ?# ?linux?
model_path = os.path.join(BASE_DIR, "data", "model_data", "unet_voc.h5")

max_epoch = 100 ?# 总迭代轮
Batch_size = 1 ??
inputs_size = [224, 224, 3]
num_classes = 2 ?# 模型输出通道数, 这包含背景类别数,本例中为 1+1=2 ? #?
lr = 1e-4
decay_rate = 0.95 ?# 指数衰减参数,每个epoch之后,学习率衰减率

import datetime
curr_time = datetime.datetime.now()
time_str = datetime.datetime.strftime(curr_time, '%Y_%m_%d_%H_%M_%S')
loss_history = LossHistory("logs/", time_str)
log_dir = os.path.join(BASE_DIR, "logs", "loss_" + time_str)
print("日志文件夹位于:{}".format(log_dir))

# step1:数据集创建
with open(os.path.join(dataset_path, "ImageSets/Segmentation/train.txt"), "r") as f:
? ? train_lines = f.readlines()
with open(os.path.join(dataset_path, "ImageSets/Segmentation/val.txt"), "r") as f:
? ? val_lines = f.readlines()

epoch_size = len(train_lines) // Batch_size ?# 计算一个epoch有几个iteration
epoch_size_val = len(val_lines) // Batch_size

# ?利用生成器创建dataset
gen = Generator(Batch_size, train_lines, inputs_size, num_classes, dataset_path)
gen = tf.data.Dataset.from_generator(partial(gen.generate, random_data=True), (tf.float32, tf.float32))
gen = gen.shuffle(buffer_size=Batch_size).prefetch(buffer_size=Batch_size)

gen_val = Generator(Batch_size, val_lines, inputs_size, num_classes, dataset_path)
gen_val = tf.data.Dataset.from_generator(partial(gen_val.generate, random_data=False), (tf.float32, tf.float32))
gen_val = gen_val.shuffle(buffer_size=Batch_size).prefetch(buffer_size=Batch_size)

if epoch_size == 0 or epoch_size_val == 0:
? ? raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

# step2: 创建model
model = Unet(inputs_size, num_classes) ?# 需要传入输入大小和最终输出通道数
model.load_weights(model_path, by_name=True, skip_mismatch=True) ?# 加载预训练模型

# step3:创建loss及优化器
loss = CE()
lr_schedule = ExponentialDecay(initial_learning_rate=lr, decay_steps=epoch_size,decay_rate=decay_rate, staircase=True)
optimizer = Adam(learning_rate=lr_schedule)

# step4:迭代训练
for epoch in range(max_epoch):
? ? fit_one_epoch(model, loss, optimizer, epoch, epoch_size, epoch_size_val, gen, gen_val, max_epoch, get_train_step_fn())
? ? path_model = os.path.join(log_dir, "model_weight_{}.h5".format(time_str))
? ? model.save_weights(path_model)
? ? print("Epoch:{}, model save at :{}".format(epoch, path_model))

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-03 11:53:35  更:2021-09-03 11:55:31 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 19:53:07-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码