IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> tf2训练模型的编写方法集合 -> 正文阅读

[Python知识库]tf2训练模型的编写方法集合

?

tf.keras.backend.clear_session()
 
model = models.Sequential()
model.add(layers.Dense(64, input_dim=64,
                kernel_regularizer=regularizers.l2(0.01), 
                activity_regularizer=regularizers.l1(0.01),
                kernel_constraint = constraints.MaxNorm(max_value=2, axis=0))) 
model.add(layers.Dense(10,
        kernel_regularizer=regularizers.l1_l2(0.01,0.01),activation = "sigmoid"))
model.compile(optimizer = "rmsprop",
        loss = "sparse_categorical_crossentropy",metrics = ["AUC"])
model.summary()

API法

model=tf.keras.Sequential()
model.add(tf.keras.layers.Dense(10,input_shape=(3,),activation='relu'))
model.add(tf.keras.layers.Dense(6,activation='relu'))

model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(1,activation='sigmoid') )

model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
? ? ? ? ? ? ? ????????????????loss='mse', # 'binary_crossentropy',
? ? ? ? ? ? ? ????????????????metrics=['acc'])

model.fit(x,y,epochs=500)

函数法

input=tf.keras.Input(shape=(28,28))
x=tf.keras.layers.Flatten()(input)
x=tf.keras.layers.Dense(32,activation='relu')(x)
x=tf.keras.layers.Dropout(0.5)(x)
x=tf.keras.layers.Dense(64,activation='relu')(x)
output=tf.keras.layers.Dense(10,activation='softmax')(x)

model=tf.keras.Model(input,output,)

model.compile(? optimizer='adam',
? ? ? ? ? ? ? loss='sparse_categorical_crossentropy', ? #连续编码用,只有一个序号;如果用? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?#one-hot独热编码(即将一个序号编码成独1多0输出)时删掉 sparse_
? ? ? ? ? ? ? metrics=['acc']?)

分步详解法

w1 = tf.Variable(tf.random.normal([2, 11]), dtype=tf.float32)
b1 = tf.Variable(tf.constant(0.01, shape=[11]))

w2 = tf.Variable(tf.random.normal([11, 1]), dtype=tf.float32)
b2 = tf.Variable(tf.constant(0.01, shape=[1]))

lr = 0.01 ? ?# 学习率
epoch = 300 ?# 循环轮数

# 训练部分
for epoch in range(epoch):
? ? for step, (x_train, y_train) in enumerate(train_db):
? ? ? ? with tf.GradientTape() as tape:? ? ? ? ?# 记录梯度信息

? ? ? ? ? ? h1 = tf.matmul(x_train, w1) + b1 ?# 记录神经网络乘加运算
? ? ? ? ? ? h1 = tf.nn.relu(h1)
? ? ? ? ? ? y = tf.matmul(h1, w2) + b2
? ? ? ? ?
? ? ? ? ? ? # 采用均方误差损失函数mse = mean(sum(y-out)^2)
? ? ? ? ? ? loss_mse = tf.reduce_mean(tf.square(y_train - y))
? ? ? ? ? ? # 添加l2正则化
? ? ? ? ? ? loss_regularization = []
? ? ? ? ? ? # tf.nn.l2_loss(w)=sum(w ** 2) / 2
? ? ? ? ? ? loss_regularization.append(tf.nn.l2_loss(w1))
? ? ? ? ? ? loss_regularization.append(tf.nn.l2_loss(w2))
? ? ? ? ? ? # 求和
? ? ? ? ? ? # 例:x=tf.constant(([1,1,1],[1,1,1]))
? ? ? ? ? ? # ? tf.reduce_sum(x)
? ? ? ? ? ? # >>>6
? ? ? ? ? ? # loss_regularization = tf.reduce_sum(tf.stack(loss_regularization))
? ? ? ? ? ? loss_regularization = tf.reduce_sum(loss_regularization)
? ? ? ? ? ? loss = loss_mse + 0.03 * loss_regularization #REGULARIZER = 0.03 ? ? ? ? ? ?

? ? ? ? # 计算loss对各个参数的梯度 ? ? ?
? ? ? ? grads = tape.gradient(loss, [w1, b1, w2, b2])

? ? ? ? # 实现梯度更新
? ? ? ? # w1 = w1 - lr * w1_grad tape.gradient是自动求导结果与[w1, b1, w2, b2] 索引为0,1,2,3?
? ? ? ? w1.assign_sub(lr * grads[0])
? ? ? ? b1.assign_sub(lr * grads[1])
? ? ? ? w2.assign_sub(lr * grads[2])
? ? ? ? b2.assign_sub(lr * grads[3])

? ? # 每20个epoch,打印loss信息
? ? if epoch % 20 == 0:
? ? ? ? print('epoch:', epoch, 'loss:', float(loss))

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-10-17 12:28:59  更:2022-10-17 12:29:14 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/26 3:43:13-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计