IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> DJL-Java动手学深度学习之线性回归实现 -> 正文阅读

[人工智能]DJL-Java动手学深度学习之线性回归实现

在上一篇文章中,我们已经讲了整个线性回归的数学原理,详细请见 DJL-Java开发者动手学深度学习之线性回归。在这里,我们就用DJL来实现利用线性回归来预测房价。
为了简单起见,我们利用生成数据加随机噪声的方式,来降低我们的学习成本。

生成数据


public static DataPoints syntheticData(NDManager manager, NDArray w, float b, int numExamples) {
        NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));
        NDArray y = X.dot(w).add(b);
        // Add noise
        y = y.add(manager.randomNormal(0, 0.01f, y.getShape(), DataType.FLOAT32));
        return new DataPoints(X, y);
    }

定义模型

我们可以 DJL 预定义的 Block。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。我们首先定义一个模型变量net,它是一个 SequentialBlock 类的实例。 SequentialBlock 类为串联在一起的多个层定义了一个容器。当给定输入数据, SequentialBlock 实例将数据传入到第一层,然后将第一层的输出作为第二层的输入,依此类推。

//定义模式,一层神经网络
Model model = Model.newInstance("lin-reg");

SequentialBlock net = new SequentialBlock();
Linear linearBlock = Linear.builder().optBias(true).setUnits(1).build();
net.add(linearBlock);

model.setBlock(net);

定义损失函数

在 DJL 中,抽象类 Loss 定义了损失函数的接口。在这里我们将使用平方损失 L2Loss

//损失函数
Loss l2loss = Loss.l2Loss();

定义梯度下降

在这里,我们使用随机梯度下降

//随机梯度下降 sgd
Tracker lrt = Tracker.fixed(0.03f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

初始化训练器

DefaultTrainingConfig config = new DefaultTrainingConfig(l2loss)
			.optOptimizer(sgd)
			.addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);

设置运行性能指标

Metrics metrics = new Metrics();
trainer.setMetrics(metrics);

初始化模型参数

int batchSize = 10;
trainer.initialize(new Shape(batchSize, 2));

加载数据并开始训练

public static  ArrayDataset loadArray(NDArray features, NDArray labels, int batchSize, boolean shuffle) {
	return new ArrayDataset.Builder()
				  .setData(features) // 设置训练集
				  .optLabels(labels) // 设置预测值
				  .setSampling(batchSize, shuffle) // 设置批量大小和随机抽样
				  .build();
}

ArrayDataset  dataset = loadArray(features, labels, batchSize, false);
int numEpochs = 2;

for (int epoch = 1; epoch <= numEpochs; epoch++) {
	System.out.printf("Epoch %d\n", epoch);
	// Iterate over dataset
	for (Batch batch : trainer.iterateDataset(dataset)) {
		EasyTrain.trainBatch(trainer, batch);
		trainer.step();
		batch.close();
	}
	trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}

预测

//预测
Block layer = model.getBlock();
ParameterList params = layer.getParameters();
NDArray wParam = params.valueAt(0).getArray();
NDArray bParam = params.valueAt(1).getArray();

float[] w = trueW.sub(wParam.reshape(trueW.getShape())).toFloatArray();
System.out.printf("Error in estimating w: [%f %f]\n", w[0], w[1]);
System.out.println(String.format("Error in estimating b: %f\n", trueB - bParam.getFloat()));

预测结果

Error in estimating w: [0.008057 -0.013592]
Error in estimating b: 0.013024

保存模型

到目前为止,我们已经完成整个模型的训练,接下来,我们可以保存训练好的模型,以便后续使用。

//保存模型
Path modelDir = Paths.get("./models/lin-reg");
Files.createDirectories(modelDir);
model.setProperty("Epoch", Integer.toString(numEpochs));
model.save(modelDir, "lin-reg");

关注公众号,解锁更多深度学习知识

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-05 12:02:52  更:2021-12-05 12:03:14 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 0:44:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码