问题背景
目前深度学习多使用python实现。不过想要配置好一个python的深度学习环境有时却并不轻松,常常因为各个第三方库版本兼容性问题而失败。相比之下,matlab仅需一次安装简化了不少工作。这几年matlab的深度学习工具箱也是发展迅速。但我发现matlab的相关资料却比较少。因此,我探索了下如何用matlab搭建一个BiLSTM用于时序遥感数据的分类。
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79clc,clear; %% Load the training data % trainX: an array with shape of (n, c, t). n represents the number % of training samples, c is the number of features, t is the % length of time sequence. % trainY: an array with shape of (n,). n represents the number of training % samples. rootDir = 'root_dir'; trainingData = importdata(fullfile(rootDir,'train.mat')); trainX = trainingData.trainx; trainY = trainingData.trainy+1; Xtrain = cell({}); for i = 1:size(trainX,1) Xtrain{i,1} = squeeze(trainX(i,:,:)); end Ytrain = categorical(trainY'); %% Load the validation data valData = importdata(fullfile(rootDir,'test.mat')); valX = valData.testx; valY = valData.testy+1; Yval = categorical(valY'); Xval = cell({}); for i = 1:size(valX,1) Xval{i,1} = squeeze(valX(i,:,:)); end valDataSet = cell({Xval,Yval}); %% Create bilstm model % numFeatures: The number of expected features in input data % numHiddens: The number of features in the hidden state % numClasses: The number of classess numFeatures = 8; numHiddens = 256; numClasses = 5; netLayers = [ sequenceInputLayer(numFeatures,"Name","input") bilstmLayer(numHiddens,"Name","bilstm_1",'OutputMode','last') bilstmLayer(numHiddens,"Name","bilstm_2",'OutputMode','last') dropoutLayer(0.5,"Name","dropout") flattenLayer("Name","flatten") fullyConnectedLayer(numClasses,"Name","fc") softmaxLayer("Name","softmax") classificationLayer("Name","classification")]; %% Set the hyper parameters for unet training options = trainingOptions('adam', ... 'InitialLearnRate',1e-4, ... 'Plots','training-progress',... 'MaxEpochs',60, ... 'MiniBatchSize',128,... 'VerboseFrequency',1,... 'ExecutionEnvironment', 'auto',... 'Shuffle','every-epoch',... 'ValidationData',valDataSet,... 'ValidationFrequency',1,... 'WorkerLoad',4,... 'CheckPointPath',rootDir); % start training net = trainNetwork(Xtrain,Ytrain, netLayers, options); %% Save and load model save('bilstm.mat','net'); bilstm = importdata('bilstm.mat'); %% Accuracy assessment pred = classify(bilstm, Xval); [confusionMatrix,order] = confusionmat(categorical(valY),pred); cm = confusionchart(confusionMatrix); % caculate user accuracy and mapping accuracy confusionMatrix = [confusionMatrix, zeros(size(confusionMatrix,1),1)]; confusionMatrix = [confusionMatrix; zeros(1,size(confusionMatrix,2))]; confusionMatrix(1:end-1,end) = confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses))... ./sum(confusionMatrix(1:end-1,1:end-1),2)'; confusionMatrix(end,1:end-1) = confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses))... ./sum(confusionMatrix(1:end-1,1:end-1),1); confusionMatrix(end,end) = sum(confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses)))... ./sum(sum(confusionMatrix(1:end-1,1:end-1))); mappingAccuracy = confusionMatrix(end,1:end-1); userAccuracy = confusionMatrix(1:end-1,end); totalAccuracy = confusionMatrix(end,end);
分类精度还不错,达到了88.9%
最后
以上就是疯狂冬瓜最近收集整理的关于基于Matlab的BiLSTM实现问题背景的全部内容,更多相关基于Matlab内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复