[PointNet] PointNet-based 3D point cloud object classification and recognition, simulated in MATLAB
2022-06-24 06:41:00 【fpga和matlab】
1. Software Version
MATLAB 2021a
2. System Overview
The PointNet network structure used here is shown in the figure below.
In the overall architecture, the network first performs set abstraction, which partitions the point cloud into local regions and extracts features from them. As the figure shows, a set abstraction level consists of three layers: a Sampling layer, a Grouping layer, and a PointNet layer. The sampling layer selects centroid points using the farthest point sampling (FPS) algorithm; the grouping layer forms local groups around those centroids using MRG or MSG; finally, a PointNet extracts features from the grouped points. With MSG, the first set abstraction level selects 512 centroids, with radii of 0.1, 0.2, and 0.4 and maximum group sizes of 16, 32, and 128 points, respectively.
Sampling layer
The sampling layer selects a set of points from the input cloud that define the centers of local regions. It uses iterative farthest point sampling (FPS): first pick a random point, then repeatedly add the point farthest from the points already selected, until the required number is reached. Compared with random sampling, FPS covers the whole point cloud more evenly with the chosen region centers.
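The FPS procedure described above can be sketched as follows. This is a minimal illustration in Python/NumPy (the function name and interface are mine, not part of the MATLAB project):

```python
import numpy as np

def farthest_point_sampling(points, k, seed=0):
    """Iteratively pick k centroid indices: start from a random point,
    then repeatedly take the point farthest from all chosen centroids."""
    rng = np.random.default_rng(seed)
    n = points.shape[0]
    chosen = [int(rng.integers(n))]
    # Distance from every point to its nearest chosen centroid so far.
    dist = np.linalg.norm(points - points[chosen[0]], axis=1)
    for _ in range(k - 1):
        nxt = int(np.argmax(dist))        # farthest from current set
        chosen.append(nxt)
        dist = np.minimum(dist, np.linalg.norm(points - points[nxt], axis=1))
    return np.asarray(chosen)
```

Because each new centroid maximizes the distance to the already-selected set, the samples spread out over the whole cloud instead of clustering in dense regions.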
Grouping layer
The goal is to build local regions from which features can be extracted, using each centroid's neighboring points. The paper uses a neighborhood ball rather than kNN, because a ball query guarantees a fixed region scale; the main criterion is simply the distance between points.
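A neighborhood-ball query can be sketched as below (a hypothetical Python illustration, not the toolbox implementation). Unlike kNN, every returned neighbor is guaranteed to lie within a fixed radius of the centroid:

```python
import numpy as np

def ball_query(points, centroids, radius, max_points):
    """For each centroid, gather up to max_points indices of points
    within the given radius (a fixed-scale neighborhood, unlike kNN)."""
    groups = []
    for c in centroids:
        d = np.linalg.norm(points - c, axis=1)
        idx = np.flatnonzero(d <= radius)[:max_points]  # cap group size
        groups.append(idx)
    return groups
```

With MSG, this query would be run at several radii (e.g. 0.1, 0.2, 0.4) around the same centroids and the resulting features concatenated.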
Pointnet layer
For extracting local features from a point cloud, the original PointNet already works well, so in PointNet++ the PointNet becomes a sub-network applied hierarchically, extracting features level by level.
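The essence of a PointNet layer is a shared per-point MLP followed by a symmetric max pooling, which makes the output invariant to the ordering of the input points. A minimal sketch (hypothetical names, plain NumPy):

```python
import numpy as np

def mini_pointnet(points, W1, b1, W2, b2):
    """Minimal PointNet feature extractor: the same two-layer MLP is
    applied to every point (shared weights), then max pooling over the
    point dimension yields an order-invariant global feature."""
    h = np.maximum(points @ W1 + b1, 0.0)   # shared layer 1: (N, H)
    h = np.maximum(h @ W2 + b2, 0.0)        # shared layer 2: (N, F)
    return h.max(axis=0)                    # max pool over points: (F,)
```

In PointNet++, this block is applied to each local group produced by the grouping layer, and the pooled features become the "points" of the next set abstraction level.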
3. Core Program (excerpt)
clc;
clear;
close all;
warning off;
addpath(genpath(pwd));
rng('default')
%****************************************************************************
% For more on MATLAB and FPGA, see the "fpga和matlab" CSDN blog:
% MATLAB/FPGA project development cooperation
% https://blog.csdn.net/ccsss22?type=blog
%****************************************************************************
dsTrain = PtCloudClassificationDatastore('train');
dsVal = PtCloudClassificationDatastore('test');
ptCloud = pcread('Chair.ply');
label = 'Chair';
figure;pcshow(ptCloud)
xlabel("X");ylabel("Y");zlabel("Z");title(label)
dsLabelCounts = transform(dsTrain,@(data){data{2} data{1}.Count});
labelCounts = readall(dsLabelCounts);
labels = vertcat(labelCounts{:,1});
counts = vertcat(labelCounts{:,2});
figure;histogram(labels);title('class distribution')
rng(0)
[G,classes] = findgroups(labels);
numObservations = splitapply(@numel,labels,G);
desiredNumObservationsPerClass = max(numObservations);
filesOverSample=[];
for i=1:numel(classes)
    % Files are grouped by class, so class i occupies indices
    % sum(numObservations(1:i-1))+1 through sum(numObservations(1:i)).
    startIdx = sum(numObservations(1:i-1)) + 1;
    endIdx   = sum(numObservations(1:i));
    targetFiles = {dsTrain.Files{startIdx:endIdx}};
    % Randomly replicate the point clouds belonging to the infrequent classes.
    files = randReplicateFiles(targetFiles,desiredNumObservationsPerClass);
    filesOverSample = vertcat(filesOverSample,files');
end
dsTrain.Files=filesOverSample;
dsTrain.Files = dsTrain.Files(randperm(length(dsTrain.Files)));
dsTrain.MiniBatchSize = 32;
dsVal.MiniBatchSize = dsTrain.MiniBatchSize;
dsTrain = transform(dsTrain,@augmentPointCloud);
data = preview(dsTrain);
ptCloud = data{1,1};
label = data{1,2};
figure;pcshow(ptCloud.Location,[0 0 1],"MarkerSize",40,"VerticalAxisDir","down")
xlabel("X");ylabel("Y");zlabel("Z");title(label)
minPointCount = splitapply(@min,counts,G);
maxPointCount = splitapply(@max,counts,G);
meanPointCount = splitapply(@(x)round(mean(x)),counts,G);
stats = table(classes,numObservations,minPointCount,maxPointCount,meanPointCount)
numPoints = 1000;
dsTrain = transform(dsTrain,@(data)selectPoints(data,numPoints));
dsVal = transform(dsVal,@(data)selectPoints(data,numPoints));
dsTrain = transform(dsTrain,@preprocessPointCloud);
dsVal = transform(dsVal,@preprocessPointCloud);
data = preview(dsTrain);
figure;pcshow(data{1,1},[0 0 1],"MarkerSize",40,"VerticalAxisDir","down");
xlabel("X");ylabel("Y");zlabel("Z");title(data{1,2})
inputChannelSize = 3;
hiddenChannelSize1 = [64,128];
hiddenChannelSize2 = 256;
[parameters.InputTransform, state.InputTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);
inputChannelSize = 3;
hiddenChannelSize = [64 64];
[parameters.SharedMLP1,state.SharedMLP1] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);
inputChannelSize = 64;
hiddenChannelSize1 = [64,128];
hiddenChannelSize2 = 256;
[parameters.FeatureTransform, state.FeatureTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);
inputChannelSize = 64;
hiddenChannelSize = 64;
[parameters.SharedMLP2,state.SharedMLP2] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);
inputChannelSize = 64;
hiddenChannelSize = [512,256];
numClasses = numel(classes);
[parameters.ClassificationMLP, state.ClassificationMLP] = initializeClassificationMLP(inputChannelSize,hiddenChannelSize,numClasses);
numEpochs = 60;
learnRate = 0.001;
l2Regularization = 0.1;
learnRateDropPeriod = 15;
learnRateDropFactor = 0.5;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
avgGradients = [];
avgSquaredGradients = [];
[lossPlotter, trainAccPlotter,valAccPlotter] = initializeTrainingProgressPlot;
% Number of classes
numClasses = numel(classes);
% Initialize the iterations
iteration = 0;
% To calculate the time for training
start = tic;
% Loop over the epochs
for epoch = 1:numEpochs
    % Reset training and validation datastores.
    reset(dsTrain);
    reset(dsVal);
    % Iterate through the training set.
    while hasdata(dsTrain) % when no data is left, exit the loop and start the next epoch
        iteration = iteration + 1;
        % Read a mini-batch of data.
        data = read(dsTrain);
        % Create the batch.
        [XTrain,YTrain] = batchData(data,classes);
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients, loss, state, acc] = dlfeval(@modelGradients,XTrain,YTrain,parameters,state);
        % L2 regularization.
        gradients = dlupdate(@(g,p) g + l2Regularization*p,gradients,parameters);
        % Update the network parameters using the Adam optimizer.
        [parameters, avgGradients, avgSquaredGradients] = adamupdate(parameters, gradients, ...
            avgGradients, avgSquaredGradients, iteration,learnRate,gradientDecayFactor, squaredGradientDecayFactor);
        % Update the training progress.
        D = duration(0,0,toc(start),"Format","hh:mm:ss");
        title(lossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D))
        addpoints(lossPlotter,iteration,double(gather(extractdata(loss))))
        addpoints(trainAccPlotter,iteration,acc);
        drawnow
    end
    % Create the confusion matrix.
    cmat = sparse(numClasses,numClasses);
    % Classify the validation data to monitor the training process.
    while hasdata(dsVal)
        data = read(dsVal); % Get the next batch of data.
        [XVal,YVal] = batchData(data,classes); % Create the batch.
        % Compute label predictions.
        isTrainingVal = 0; % Set to zero (false) for validation data.
        YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);
        % Choose the prediction with the highest score as the class label.
        [~,YValLabel] = max(YVal,[],1);
        [~,YPredLabel] = max(YPred,[],1);
        cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel); % Update the confusion matrix.
    end
    % Update the training progress plot with the average classification accuracy.
    acc = sum(diag(cmat))./sum(cmat,"all");
    addpoints(valAccPlotter,iteration,acc);
    % Update the learning rate.
    if mod(epoch,learnRateDropPeriod) == 0
        learnRate = learnRate * learnRateDropFactor;
    end
    reset(dsTrain); % Reset the training datastore since all training data has been read.
    % Shuffle the data at every epoch.
    dsTrain.UnderlyingDatastore.Files = dsTrain.UnderlyingDatastore.Files(randperm(length(dsTrain.UnderlyingDatastore.Files)));
    reset(dsVal);
end
cmat = sparse(numClasses,numClasses); % Preallocate the numClasses-by-numClasses confusion matrix
reset(dsVal); % Reset the validation data
data = readall(dsVal); % Read all validation data
[XVal,YVal] = batchData(data,classes); % Create batch.
% Classify the validation data using the helper function pointnetClassifier
isTrainingVal = 0; % inference mode
YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);
% Choose the prediction with the highest score as the class label.
[~,YValLabel] = max(YVal,[],1);
[~,YPredLabel] = max(YPred,[],1);
% Collect confusion metrics.
cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel);
figure;chart = confusionchart(cmat,classes);
acc = sum(diag(cmat))./sum(cmat,"all")
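The final accuracy in the script is the trace of the confusion matrix divided by its total, i.e. `sum(diag(cmat))/sum(cmat,"all")`. The same computation, as a small language-neutral sketch (Python, with hypothetical names):

```python
import numpy as np

def confusion_and_accuracy(y_true, y_pred, num_classes):
    """Accumulate a confusion matrix (rows: true class, columns:
    predicted class) and compute overall accuracy as trace/total,
    matching sum(diag(cmat))/sum(cmat,"all") in the MATLAB script."""
    cmat = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cmat[t, p] += 1
    return cmat, cmat.trace() / cmat.sum()
```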
4. Simulation Results




5. References
[1] Qi C R, Su H, Mo K, et al. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation[C]// 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017.