提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
介绍MATLAB深度学习工具使用方法,并提供了两个应用实例,第一个是手写数字识别,第二个是用LSTM预测海洋温度变化。文中给出了完整的代码。
提示:以下是本篇文章正文内容,下面案例可供参考
一、打开方式
1)在命令框输入deepNetworkDesigner;(推荐使用) 2)如下方式,通过点击图标的形式进入: a.点击主页的APP:
b. 点击所标记的下拉三角箭头
c.输入deep即会出现Deep Network Designer。
- 实例1 手写数字识别
二、 功能和构成
(1)Layer库 网络net是由层Layer连接而成的。 Layer库是整个神经网络的基本组成单元,非常重要,所有复杂的神经网络均通过这些基本单元进行搭建。 Layer库一共有9种,不同的种类通过颜色进行区分,可以直接拖动到中间的设计区进行连接。 1)INPUT输入层 2)CONVOLUTION AND FULLY CONNECTED卷积和全连接层 3)SEQUENCE序列层 4)ACTIVATION激活层 5)NORMALIZATION AND UTILITY归一化层 6)POOLING池化层 7)COMBINATION组合层 8)OBJECT DETECTION目标检测层 9)OUTPUT输出层 (2)设计区 设计区包含了三个卡片,Designer,Data,Training。 其中在Designer中进行Layer组件的布局; Data中导入训练数据; Training中进行网络训练。 (3)Analyze功能 可以对设计的网络进行评估,包含warnings和errors。另外在分析的ANALYSIS RESULT区域可以看到各层的名称Name、类型Type、维度Activations和可学习的参数Learnables。 目前的情况是,可学习的参数只有卷积和全连接层中组件的Weights和Bias。其他各组件均无待学习参数,只有一些需要配置的超参数等。 (4)Export功能 export功能可以把设计的网络导出为一个对象,加载到Workspace中;或者导出生成代码到Live Editor中便于进一步处理。
三、应用实例
1.手写数字识别
代码如下(示例):
% 加载数据集
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
% 展示数据集
figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
subplot(4,5,i);
imshow(imds.Files{perm(i)});
end
% 划分数据集和测试集出来
numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');
% 使用搭建的神经网络
layers = [ ...
imageInputLayer([28 28 1])
convolution2dLayer(5,20)
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
% 配置训练参数
options = trainingOptions('sgdm', ...
'MaxEpochs',20,...
'InitialLearnRate',1e-4, ...
'Verbose',false, ...
'Plots','training-progress');
% 训练神经网络
ile = gpuArray(0.0001);
net = trainNetwork(imdsTrain,layers,options);
% 使用神经网络,这是分类的例子。使用classify.搭配神经网络中的最后一层classificationLayer
% 如果是回归的神经网络,则神经网络的最后一层是regressionLayer,搭配predict使用,就是替换classify为predict
YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;
accuracy = sum(YPred == YTest)/numel(YTest)
测试结果 数据集下载地址:[https://download.csdn.net/download/rookiefish/85121493]
2.基于MATLAB的LSTM预测ENSO海温在全球变暖下的变化
厄尔尼诺-南方涛动(ENSO)是影响全球极端气候的重要因子,因此预估ENSO海温在全球变暖下的变化也是预估未来全球变暖下全球极端气候变化的重要因素之一。遗憾的是,最近20年来多次的全球耦合模式比较计划(CMIP)尽管对气候平均态和ENSO本身的模拟都取得了长足的进步,但是对未来增暖情形下ENSO海温异常强度的变化都存在显著的模式间差异。最近CMIP5预估的ENSO海温增强或减弱的模式个数基本相当,模式间的标准差远大于多模式集合平均的结果。因此揭示各模式预估ENSO海温振幅变化存在显著差异的核心物理过程是未来进一步改进模式、提高模式预估可信度的必要途径。 代码如下(示例):
close all;clear all;clc;
rand('seed',10); %设置随机数种子
%% I.加载数据
load enso
data_x=month';
data_y=pressure';
%% II.数据预处理
mu=mean(data_y);%计算均值
sig=std(data_y);%计算标准差
data_y=(data_y-mu)/sig;%数据归一化
%% III.数据准备
wd=5;
len=numel(data_y);%计算data_y中元素数目
wdata=[];
for i=1:1:len-wd
di=data_y(i:i+wd);
wdata=[wdata;di];
end
wdata_origin=wdata;
index_list=randperm(size(wdata,1));%整数随机排序
ind=round(0.8*length(index_list));%四舍五入
train_index=index_list(1:ind);
test_index=index_list(ind+1:end);
train_index=sort(train_index);
test_index=sort(test_index);
%% IV.划分训练集、测试集的数据和标签
dataTrain=wdata(train_index,:);
dataTest=wdata(test_index,:);
XTrain=dataTrain(:,1:end-1)';
YTrain=dataTrain(:,end)';
XTest=dataTest(:,1:end-1)';
YTest=dataTest(:,end)';
%% V.网络构建
layers=get_lstm_net(wd);
options=trainingOptions('adam',...
'MaxEpochs',1000,...
'GradientThreshold',1,...
'InitialLearnRate',0.005, ...
'LearnRateSchedule','piecewise',...
'LearnRateDropPeriod',125,...
'LearnRateDropFactor',0.2,...
'Verbose',0,...
'Plots','training-progress');
%% VI.训练
net=trainNetwork(XTrain,YTrain,layers,options);
%% VII.测试
Xall=wdata_origin(:,1:end-1)';
Yall=wdata_origin(:,end)';
YPred=predict(net,Xall,'MiniBatchSize',1);
rmse=mean((YPred(:)-Yall(:)).^2);
%% VIII.显示
figure,
subplot(2,1,1)
plot(data_x(1:length(Yall)),Yall)
hold on;
plot(data_x(1:length(Yall)),YPred,'.-')
hold off;
legend(['Real','Predict'])
ylabel('Data')
title(sprintf('LSTM分析-RMSE=%.2f',rmse));
subplot(2,1,2)
stem(data_x(1:length(Yall)),YPred-Yall)
xlabel('Time');ylabel('Error');
title('LSTM分析-误差图');
网络结构函数
function layers=get_lstm_net(wd)
%网络架构
numFeatures=wd;
numResponses=1;
numHiddenUnits=250;
layers=[sequenceInputLayer(numFeatures)
lstmLayer(numHiddenUnits)
dropoutLayer(0.1)
lstmLayer(2*numHiddenUnits)
dropoutLayer(0.1)
fullyConnectedLayer(numResponses)
regressionLayer];
end
预测结果
总结
以上就是今天要讲的内容,本文仅仅简单介绍了MATLAB深度学习工具的使用,并提供了两个实例。
|