我是靠谱客的博主 疯狂冬瓜,最近开发中收集的这篇文章主要介绍基于Matlab的BiLSTM实现问题背景,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

问题背景

目前深度学习多使用python实现。不过想要配置好一个python的深度学习环境有时却并不轻松,常常因为各个第三方库版本兼容性问题而失败。相比之下,matlab仅需一次安装简化了不少工作。这几年matlab的深度学习工具箱也是发展迅速。但我发现matlab的相关资料却比较少。因此,我探索了下如何用matlab搭建一个BiLSTM用于时序遥感数据的分类。

clc,clear;
%% Load the training data
% trainX: an array with shape of (n, c, t). n represents the number
%
of training samples, c is the number of features, t is the
%
length of time sequence.
% trainY: an array with shape of (n,). n represents the number of training
%
samples.
rootDir = 'root_dir';
trainingData = importdata(fullfile(rootDir,'train.mat'));
trainX = trainingData.trainx;
trainY = trainingData.trainy+1;
Xtrain = cell({});
for i = 1:size(trainX,1)
Xtrain{i,1} = squeeze(trainX(i,:,:));
end
Ytrain = categorical(trainY');
%% Load the validation data
valData = importdata(fullfile(rootDir,'test.mat'));
valX = valData.testx;
valY = valData.testy+1;
Yval = categorical(valY');
Xval = cell({});
for i = 1:size(valX,1)
Xval{i,1} = squeeze(valX(i,:,:));
end
valDataSet = cell({Xval,Yval});
%% Create bilstm model
% numFeatures: The number of expected features in input data
% numHiddens: The number of features in the hidden state
% numClasses: The number of classess
numFeatures = 8;
numHiddens = 256;
numClasses = 5;
netLayers = [
sequenceInputLayer(numFeatures,"Name","input")
bilstmLayer(numHiddens,"Name","bilstm_1",'OutputMode','last')
bilstmLayer(numHiddens,"Name","bilstm_2",'OutputMode','last')
dropoutLayer(0.5,"Name","dropout")
flattenLayer("Name","flatten")
fullyConnectedLayer(numClasses,"Name","fc")
softmaxLayer("Name","softmax")
classificationLayer("Name","classification")];
%% Set the hyper parameters for unet training
options = trainingOptions('adam', ...
'InitialLearnRate',1e-4, ...
'Plots','training-progress',...
'MaxEpochs',60, ...
'MiniBatchSize',128,...
'VerboseFrequency',1,...
'ExecutionEnvironment', 'auto',...
'Shuffle','every-epoch',...
'ValidationData',valDataSet,...
'ValidationFrequency',1,...
'WorkerLoad',4,...
'CheckPointPath',rootDir);
% start training
net = trainNetwork(Xtrain,Ytrain, netLayers, options);
%% Save and load model
save('bilstm.mat','net');
bilstm = importdata('bilstm.mat');
%% Accuracy assessment
pred = classify(bilstm, Xval);
[confusionMatrix,order] = confusionmat(categorical(valY),pred);
cm = confusionchart(confusionMatrix);
% caculate user accuracy and mapping accuracy
confusionMatrix = [confusionMatrix, zeros(size(confusionMatrix,1),1)];
confusionMatrix = [confusionMatrix; zeros(1,size(confusionMatrix,2))];
confusionMatrix(1:end-1,end) = confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses))...
./sum(confusionMatrix(1:end-1,1:end-1),2)';
confusionMatrix(end,1:end-1) = confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses))...
./sum(confusionMatrix(1:end-1,1:end-1),1);
confusionMatrix(end,end) = sum(confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses)))...
./sum(sum(confusionMatrix(1:end-1,1:end-1)));
mappingAccuracy = confusionMatrix(end,1:end-1);
userAccuracy = confusionMatrix(1:end-1,end);
totalAccuracy = confusionMatrix(end,end);

分类精度还不错,达到了88.9%

 

最后

以上就是疯狂冬瓜为你收集整理的基于Matlab的BiLSTM实现问题背景的全部内容,希望文章能够帮你解决基于Matlab的BiLSTM实现问题背景所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(74)

评论列表共有 0 条评论

立即
投稿
返回
顶部