示例:JS 部分(js/iris.js)
/**
 * Page bootstrap for the iris demo, run by jQuery once the DOM is ready.
 *
 * 1. Binds the submit handler of form `#fi` right away; until training has
 *    finished it only tells the user that the model is still training.
 * 2. Builds and trains a two-layer dense classifier on the data returned by
 *    `getIrisData`, plotting progress with tfjs-vis.
 * 3. Publishes `window.predict(form)`, which the submit handler then uses.
 *
 * Relies on globals provided by other scripts: `$`, `tf`, `tfvis`,
 * `getIrisData` and `IRIS_CLASSES`.
 */
$(async () => {
  // Ids of the four feature inputs inside #fi, in the order fed to the model.
  const FEATURE_IDS = ['a', 'b', 'c', 'd'];

  // Convert one raw field value to a number. Blank strings become NaN so that
  // they are rejected instead of silently being treated as 0.
  const toFeature = (raw) => {
    if (typeof raw === 'string' && raw.trim() === '') {
      return Number.NaN;
    }
    return Number(raw);
  };

  $('#fi').on('submit', () => {
    // Look the inputs up by id ("#a"); a bare "a" would match <a> elements.
    const form = {};
    for (const id of FEATURE_IDS) {
      form[id] = $(`#fi #${id}`).val();
      console.log(form[id]);
    }
    if (window.predict) {
      window.predict(form);
    } else {
      alert('模型正在训练');
    }
    // Returning false from a jQuery handler cancels the native form submit.
    return false;
  });

  // NOTE(review): 0.15 is presumably the test-split fraction - confirm in data.js.
  const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);

  // Sequential model: layers are stacked one after another.
  const model = tf.sequential();

  // Hidden fully connected layer; the input width is taken from the data.
  model.add(tf.layers.dense({
    units: 10,
    inputShape: [xTrain.shape[1]],
    activation: 'sigmoid',
  }));

  // Output layer with one unit per class (three classes). Softmax suits
  // multi-class output; sigmoid would be the choice for binary classification.
  model.add(tf.layers.dense({
    units: 3,
    activation: 'softmax',
  }));

  model.compile({
    loss: 'categoricalCrossentropy', // cross-entropy for multi-class targets
    optimizer: tf.train.adam(0.1), // Adam optimizer with learning rate 0.1
    metrics: ['accuracy'],
  });

  // Train, validating on the held-out split and plotting after every epoch.
  await model.fit(xTrain, yTrain, {
    epochs: 100,
    validationData: [xTest, yTest],
    callbacks: tfvis.show.fitCallbacks(
      { name: '训练效果' },
      ['loss', 'val_loss', 'acc', 'val_acc'],
      { callbacks: ['onEpochEnd'] }
    ),
  });

  /**
   * Classify one sample and show the predicted class name in an alert.
   *
   * @param {{a: *, b: *, c: *, d: *}} form - the four feature values, as
   *   numbers or numeric strings.
   * @returns {void}
   */
  window.predict = (form) => {
    const features = FEATURE_IDS.map((id) => toFeature(form[id]));
    if (!features.every(Number.isFinite)) {
      alert('请输入有效的数字');
      return;
    }
    // tf.tidy disposes the input and prediction tensors once the plain class
    // index has been read, so repeated submits do not leak tensor memory.
    const classIndex = tf.tidy(() => {
      const pred = model.predict(tf.tensor([features]));
      // argMax(1): index of the largest probability along the class axis.
      return pred.argMax(1).dataSync()[0];
    });
    alert(`预测结果:${IRIS_CLASSES[classIndex]}`);
  };
});
HTML 部分
<!DOCTYPE html>
<!-- Demo page: a four-field form (#fi) whose values are handled by js/iris.js. -->
<html lang="zh-CN">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Document</title>
  <!-- Libraries and data are loaded in <head>, before the page script below. -->
  <script src="js/tensorflow/tfjs.js"></script>
  <script src="js/tensorflow/tfjs-vis.js"></script>
  <script src="js/jquery/jquery.js"></script>
  <script src="js/iris-data/data.js"></script>
</head>
<body>
  <div>iris</div>
  <form id="fi">
    <label for="a">
      花萼长度:<input type="text" inputmode="decimal" name="a" id="a">
    </label><br />
    <label for="b">
      花萼宽度:<input type="text" inputmode="decimal" name="b" id="b">
    </label><br />
    <label for="c">
      花瓣长度:<input type="text" inputmode="decimal" name="c" id="c">
    </label><br />
    <label for="d">
      花瓣宽度:<input type="text" inputmode="decimal" name="d" id="d">
    </label><br />
    <button type="submit">提交</button>
  </form>
  <!-- The page script belongs inside <body>; a <script> placed after </body> is invalid HTML. -->
  <script src="js/iris.js"></script>
</body>
</html>
|