在上一篇文章中,我们已经讲了整个线性回归的数学原理,详细请见 DJL-Java开发者动手学深度学习之线性回归。在这里,我们就用DJL来实现利用线性回归来预测房价。 为了简单起见,我们利用生成数据加随机噪声的方式,来降低我们的学习成本。
生成数据
public static DataPoints syntheticData(NDManager manager, NDArray w, float b, int numExamples) {
NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));
NDArray y = X.dot(w).add(b);
y = y.add(manager.randomNormal(0, 0.01f, y.getShape(), DataType.FLOAT32));
return new DataPoints(X, y);
}
定义模型
我们可以 DJL 预定义的 Block 。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。我们首先定义一个模型变量net ,它是一个 SequentialBlock 类的实例。 SequentialBlock 类为串联在一起的多个层定义了一个容器。当给定输入数据, SequentialBlock 实例将数据传入到第一层,然后将第一层的输出作为第二层的输入,依此类推。
Model model = Model.newInstance("lin-reg");
SequentialBlock net = new SequentialBlock();
Linear linearBlock = Linear.builder().optBias(true).setUnits(1).build();
net.add(linearBlock);
model.setBlock(net);
定义损失函数
在 DJL 中,抽象类 Loss 定义了损失函数的接口。在这里我们将使用平方损失 L2Loss 。
Loss l2loss = Loss.l2Loss();
定义梯度下降
在这里,我们使用随机梯度下降
Tracker lrt = Tracker.fixed(0.03f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
初始化训练器
DefaultTrainingConfig config = new DefaultTrainingConfig(l2loss)
.optOptimizer(sgd)
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(config);
设置运行性能指标
Metrics metrics = new Metrics();
trainer.setMetrics(metrics);
初始化模型参数
int batchSize = 10;
trainer.initialize(new Shape(batchSize, 2));
加载数据并开始训练
public static ArrayDataset loadArray(NDArray features, NDArray labels, int batchSize, boolean shuffle) {
return new ArrayDataset.Builder()
.setData(features)
.optLabels(labels)
.setSampling(batchSize, shuffle)
.build();
}
ArrayDataset dataset = loadArray(features, labels, batchSize, false);
int numEpochs = 2;
for (int epoch = 1; epoch <= numEpochs; epoch++) {
System.out.printf("Epoch %d\n", epoch);
for (Batch batch : trainer.iterateDataset(dataset)) {
EasyTrain.trainBatch(trainer, batch);
trainer.step();
batch.close();
}
trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
预测
Block layer = model.getBlock();
ParameterList params = layer.getParameters();
NDArray wParam = params.valueAt(0).getArray();
NDArray bParam = params.valueAt(1).getArray();
float[] w = trueW.sub(wParam.reshape(trueW.getShape())).toFloatArray();
System.out.printf("Error in estimating w: [%f %f]\n", w[0], w[1]);
System.out.println(String.format("Error in estimating b: %f\n", trueB - bParam.getFloat()));
预测结果
Error in estimating w: [0.008057 -0.013592]
Error in estimating b: 0.013024
保存模型
到目前为止,我们已经完成整个模型的训练,接下来,我们可以保存训练好的模型,以便后续使用。
Path modelDir = Paths.get("./models/lin-reg");
Files.createDirectories(modelDir);
model.setProperty("Epoch", Integer.toString(numEpochs));
model.save(modelDir, "lin-reg");
关注公众号,解锁更多深度学习知识
|