我是靠谱客的博主 任性洋葱,这篇文章主要介绍pointnet train函数第二十七句 for epoch in range(MAX_EPOCH):,现在分享给大家,希望可以做个参考。

复制代码
1
2
3
4
5
6
7
8
9
10
11
for epoch in range(MAX_EPOCH): log_string('**** EPOCH %03d ****' % (epoch)) sys.stdout.flush() train_one_epoch(sess, ops, train_writer) eval_one_epoch(sess, ops, test_writer) # Save the variables to disk. if epoch % 10 == 0: save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt")) log_string("Model saved in file: %s" % save_path)

当前工程cls的MAX_EPOCH设置为250,当前的epoch是为了增加样本数量,因为样本有限,所以需要每个epoch打乱一次训练样本,以此来增加训练样本的总数

复制代码
1
train_one_epoch(sess, ops, train_writer)

这句则是具体train过程函数,具体实现如下

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def train_one_epoch(sess, ops, train_writer): """ ops: dict mapping from string to tf ops """ is_training = True # Shuffle train files train_file_idxs = np.arange(0, len(TRAIN_FILES)) np.random.shuffle(train_file_idxs) for fn in range(len(TRAIN_FILES)): log_string('----' + str(fn) + '-----') print(TRAIN_FILES[train_file_idxs[fn]]) current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]]) print("current_data shape") print(current_data.shape) print("current_label shape") print(current_label.shape) current_data = current_data[:,0:NUM_POINT,:] current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label)) current_label = np.squeeze(current_label) file_size = current_data.shape[0] print("current_data.shape[0]:") print(current_data.shape[0]) num_batches = file_size // BATCH_SIZE print("num_batches,BATCH_SIZE") print(num_batches) print(BATCH_SIZE) total_correct = 0 total_seen = 0 loss_sum = 0 for batch_idx in range(num_batches): start_idx = batch_idx * BATCH_SIZE end_idx = (batch_idx+1) * BATCH_SIZE # Augment batched point clouds by rotation and jittering rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :]) jittered_data = provider.jitter_point_cloud(rotated_data) feed_dict = {ops['pointclouds_pl']: jittered_data, ops['labels_pl']: current_label[start_idx:end_idx], ops['is_training_pl']: is_training,} summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict) train_writer.add_summary(summary, step) pred_val = np.argmax(pred_val, 1) correct = np.sum(pred_val == current_label[start_idx:end_idx]) total_correct += correct total_seen += BATCH_SIZE loss_sum += loss_val log_string('mean loss: %f' % (loss_sum / float(num_batches))) log_string('accuracy: %f' % (total_correct / float(total_seen)))

 

这里第一句is_training是用来设置is_training_pl tensor

第二句,TRAIN_FILES可以看看定义里面,是从train_files.txt里面读取point数据的路径,内容如下

复制代码
1
2
3
4
5
data/modelnet40_ply_hdf5_2048/ply_data_train0.h5 data/modelnet40_ply_hdf5_2048/ply_data_train1.h5 data/modelnet40_ply_hdf5_2048/ply_data_train2.h5 data/modelnet40_ply_hdf5_2048/ply_data_train3.h5 data/modelnet40_ply_hdf5_2048/ply_data_train4.h5

 调用了

复制代码
1
provider.getDataFiles,产生一个array里面放的是上面h5文件名,第二句则根据array的长度产生对应的一个index的array

第三句是将这个array顺序打乱,从而在每次epoch有不同的索引序列,在第四句for循环中读取data的时候point样本不同,达到增加训练数据的目的

然后看for循环内部,第一步是读取point data以及label data,参考pointnet provider.loadDataFile读取之后shape为三维batchsize,pointnum,xyz,的current_data因为每个h5文件中的point的pointnum不一定都是1024因此,需要对pointnum这一维进行处理,然后继续对点云顺序打乱顺序继续增加样本多样性

复制代码
1
current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))

参考pointnet shuffle_data(data, labels)

其中"_"为index np array点云处理过程中,这里的index特别重要,所以要及时记录

复制代码
1
current_label = np.squeeze(current_label)

继续把current_label 中维度为1的数据去除掉

复制代码
1
2
3
file_size = current_data.shape[0] num_batches = file_size // BATCH_SIZE

首先获取h5文件中的file_size即一个文件中有多少个point模型。因为我们的网络处理单位是batchsize*pointnum为一次,因此需要计算一个文件需要多少次bacthsize计算才能训练完毕一个文件

即num_batches

复制代码
1
2
3
total_correct = 0 total_seen = 0 loss_sum = 0

分别为总的分类准确的个数

总的训练模型个数

总的损失值

获取到数据之后根据num_batches进行loop每个loop一个batchsize的模型数进入下面的循环

复制代码
1
for batch_idx in range(num_batches):

因为current_data里面第一维是filesize,是h5文件中总的point点云模型的array,每次读取一个batchsie个,当前设置的是32个,因此需要每次循环更新起始index以及终止index,取出一个batchsize的point模型数据。

复制代码
1
2
start_idx = batch_idx * BATCH_SIZE end_idx = (batch_idx+1) * BATCH_SIZE

取出数据之后perovider进行处理,

复制代码
1
2
rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :]) jittered_data = provider.jitter_point_cloud(rotated_data)

参考pointnet provider.rotate_point_cloud provider.jitter_point_cloud,作用是对point cloud进行旋转平移,旋转角度以及平移距离为随机的,产生更多的训练样本

复制代码
1
2
3
feed_dict = {ops['pointclouds_pl']: jittered_data, ops['labels_pl']: current_label[start_idx:end_idx], ops['is_training_pl']: is_training,}

此处的代码是将我们之前的placeholder产生的rensor进行feedict才能在gragh中进行运算

复制代码
1
2
summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict)

这一句即是运行我们前二十六句建立起来的由各种tensor组成的op链接起来的gragh图,运行完即训练完了一次,返回summary可视化数据,step训练次数,_ 点云索引,loss数据,pred正向数据

复制代码
1
2
3
4
5
pred_val = np.argmax(pred_val, 1) correct = np.sum(pred_val == current_label[start_idx:end_idx]) total_correct += correct total_seen += BATCH_SIZE loss_sum += loss_val
复制代码
1
2
log_string('mean loss: %f' % (loss_sum / float(num_batches))) log_string('accuracy: %f' % (total_correct / float(total_seen)))

 

训练完一个文件内的数据,计算一次平均loss以及精确度

train完一次,接着进行test数据进行预测并且得到预测准确率参考pointnet def eval_one_epoch(sess, ops, test_writer)

至此分类训练代码解读完毕后续会持续更新更正,然后用pytorch+open3d实现一遍。敬请期待

 

 

最后

以上就是任性洋葱最近收集整理的关于pointnet train函数第二十七句 for epoch in range(MAX_EPOCH):的全部内容,更多相关pointnet内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(67)

评论列表共有 0 条评论

立即
投稿
返回
顶部