分享

Ransac

 mscdj 2018-02-02

随机一致性采样RANSAC是一种鲁棒的模型拟合算法,能够从有外点的数据中拟合准确的模型。


RANSAC过程中用到的参数

N-- 拟合模型所需要的最少的样本个数

K--算法的迭代次数

t--用于判断数据是否是内点

d--判定模型是否符合使用于数据集,也就是判断是否是好的模型


RANSAC算法过程

1  for K 次迭代

2     从数据中均匀随机采样N个点

3     利用采样的N个点拟合你个模型

4     for 对于除采样点外的每一个样本点

5          利用t检测样本点到模型的距离,如果小于t则认为是一致,否则认为是外点

6     end

7     如果有d或者更多的一致点,则认为拟合的模型是好的

8 end

9 使用拟合误差作为标准,选择最好的拟合模型



迭代次数的计算

假设 r = 内点个数/所有点的个数

 则:

   p0 = pow(r, N) 表示采样的N个点全为内点,也就是是一次有效采样的概率

   p1 = 1 - pow(r, N) 表示采样的N个点中至少有一个外点,即一次无效采样的概率

   p2 = pow(p1, K) 表示K次无效采样的概率

假设p表示K次采样中至少一次采样是有效采样,则有1-p = pow(p1, K), 两边取对数

则有 K = log(1- p )/log(1-p1).


 附一份来自google 的RANSAC的代码框架

[cpp] view plain copy
  1. #ifndef FVISION_RANSAC_H_  
  2. #define FVISION_RANSAC_H_  
  3.   
  4. #include <fvision/utils/random_utils.h>  
  5. #include <fvision/utils/misc.h>  
  6.   
  7. #include <vector>  
  8. #include <iostream>  
  9. #include <cassert>  
  10.   
  11. namespace fvision {  
  12.   
  13. class RANSAC_SamplesNumber {  
  14. public:  
  15.         RANSAC_SamplesNumber(int modelSampleSize) {  
  16.                 this->s = modelSampleSize;  
  17.                 this->p = 0.99;  
  18.         }  
  19.         ~RANSAC_SamplesNumber(void) {}  
  20.   
  21. public:  
  22.         long calcN(int inliersNumber, int samplesNumber) {  
  23.                 double e = 1 - (double)inliersNumber / samplesNumber;  
  24.                 //cout<<"e: "<<e<<endl;  
  25.                 if (e > 0.9) e = 0.9;  
  26.                 //cout<<"pow: "<<pow((1 - e), s)<<endl;  
  27.                 //cout<<log(1 - pow((1 - e), s))<<endl;  
  28.                 long N = (long)(log(1 - p) / log(1 - pow((1 - e), s)));  
  29.                 if (N < 0) return (long)1000000000;  
  30.                 else return N;  
  31.         }  
  32.   
  33. private:                  
  34.         int s;      //samples size for fitting a model  
  35.         double p;   //probability that at least one of the random samples if free from outliers  
  36.                     //usually 0.99  
  37. };  
  38.   
  39. //fit a model to a set of samples  
  40. template <typename M, typename S>  
  41. class GenericModelCalculator {  
  42. public:  
  43.         typedef std::vector<S> Samples;  
  44.         virtual M compute(const Samples& samples) = 0;  
  45.   
  46.         virtual ~GenericModelCalculator<M, S>() {}  
  47.   
  48.         //the model calculator may only use a subset of the samples for computing  
  49.         //default return empty for both  
  50.         virtual const std::vector<int>& getInlierIndices() const { return defaultInlierIndices; };  
  51.         virtual const std::vector<int>& getOutlierIndices() const { return defaultOutlierIndices; };  
  52.   
  53.         // if the subclass has a threshold parameter, it need to override the following three functions  
  54.         // this is used for algorithms which have a normalization step on input samples  
  55.         virtual bool hasThreshold() const { return false; }  
  56.         virtual void setThreshold(double threshold) {}  
  57.         virtual double getThreshold() const { return 0; }  
  58.   
  59. protected:  
  60.         std::vector<int> defaultInlierIndices;  
  61.         std::vector<int> defaultOutlierIndices;  
  62. };  
  63.   
  64. //evaluate a model to samples  
  65. //using a threshold to distinguish inliers and outliers  
  66. template <typename M, typename S>  
  67. class GenericErrorCaclculator {  
  68. public:  
  69.         virtual ~GenericErrorCaclculator<M, S>() {}  
  70.   
  71.         typedef std::vector<S> Samples;  
  72.   
  73.         virtual double compute(const M& model, const S& sample) const = 0;  
  74.   
  75.         double computeAverage(const M& model, const Samples& samples) const {  
  76.                 int n = (int)samples.size();  
  77.                 if (n == 0) return 0;  
  78.                 double sum = 0;  
  79.                 for (int i = 0; i < n; i++) {  
  80.                         sum += compute(model, samples[i]);  
  81.                 }  
  82.                 return sum / n;  
  83.         }  
  84.   
  85.         double computeInlierAverage(const M& model, const Samples& samples) const {  
  86.                 int n = (int)samples.size();  
  87.                 if (n == 0) return 0;  
  88.                 double sum = 0;  
  89.                 double error = 0;  
  90.                 int inlierNum = 0;  
  91.                 for (int i = 0; i < n; i++) {  
  92.                         error = compute(model, samples[i]);  
  93.                         if (error <= threshold) {  
  94.                                 sum += error;  
  95.                                 inlierNum++;  
  96.                         }  
  97.                 }  
  98.                 if (inlierNum == 0) return 1000000;  
  99.                 return sum / inlierNum;  
  100.         }  
  101.   
  102. public:  
  103.   
  104.         /** set a threshold for classify inliers and outliers 
  105.          */  
  106.         void setThreshold(double v) { threshold = v; }  
  107.   
  108.         double getThreshold() const { return threshold; }  
  109.   
  110.         /** classify all samples to inliers and outliers 
  111.          *  
  112.          */  
  113.         void classify(const M& model, const Samples& samples, Samples& inliers, Samples& outliers) const {  
  114.                 inliers.clear();  
  115.                 outliers.clear();  
  116.                 Samples::const_iterator iter = samples.begin();  
  117.                 for (; iter != samples.end(); ++iter) {  
  118.                         if (isInlier(model, *iter)) inliers.push_back(*iter);  
  119.                         else outliers.push_back(*iter);  
  120.                 }  
  121.         }  
  122.   
  123.         /** classify all samples to inliers and outliers, output indices 
  124.          *  
  125.          */  
  126.         void classify(const M& model, const Samples& samples, std::vector<int>& inlierIndices, std::vector<int>& outlierIndices) const {  
  127.                 inlierIndices.clear();  
  128.                 outlierIndices.clear();  
  129.                 Samples::const_iterator iter = samples.begin();  
  130.                 int i = 0;  
  131.                 for (; iter != samples.end(); ++iter, ++i) {  
  132.                         if (isInlier(model, *iter)) inlierIndices.push_back(i);  
  133.                         else outlierIndices.push_back(i);  
  134.                 }  
  135.         }  
  136.   
  137.         /** classify all samples to inliers and outliers 
  138.          *  
  139.          */  
  140.         void classify(const M& model, const Samples& samples,   
  141.                 std::vector<int>& inlierIndices, std::vector<int>& outlierIndices,   
  142.                 Samples& inliers, Samples& outliers) const {  
  143.   
  144.                 inliers.clear();  
  145.                 outliers.clear();  
  146.                 inlierIndices.clear();  
  147.                 outlierIndices.clear();  
  148.                 Samples::const_iterator iter = samples.begin();  
  149.                 int i = 0;  
  150.                 for (; iter != samples.end(); ++iter, ++i) {  
  151.                         if (isInlier(model, *iter)) {  
  152.                                 inliers.push_back(*iter);  
  153.                                 inlierIndices.push_back(i);  
  154.                         }  
  155.                         else {  
  156.                                 outliers.push_back(*iter);  
  157.                                 outlierIndices.push_back(i);  
  158.                         }  
  159.                 }  
  160.         }  
  161.   
  162.         int calcInliersNumber(const M& model, const Samples& samples) const {  
  163.                 int n = 0;  
  164.                 for (int i = 0; i < (int)samples.size(); i++) {  
  165.                         if (isInlier(model, samples[i])) ++n;  
  166.                 }  
  167.                 return n;  
  168.         }  
  169.   
  170.         bool isInlier(const M& model, const S& sample) const {  
  171.                 return (compute(model, sample) <= threshold);  
  172.         }  
  173.   
  174. private:  
  175.         double threshold;  
  176. };  
  177.   
  178. /** generic RANSAC framework 
  179.  * make use of a model calculator and an error calculator 
  180.  * M is the model type, need to support copy assignment operator and default constructor 
  181.  * S is the sample type. 
  182.  * 
  183.  * Interface: 
  184.  *  M compute(samples); input a set of samples, output a model.  
  185.  *  after compute, inliers and outliers can be retrieved 
  186.  *  
  187.  */  
  188. template <typename M, typename S>  
  189. class Ransac : public GenericModelCalculator<M, S> {  
  190. public:  
  191.         typedef std::vector<S> Samples;  
  192.   
  193.         /** Constructor 
  194.          *  
  195.          * @param pmc a GenericModelCalculator object 
  196.          * @param modelSampleSize how much samples are used to fit a model 
  197.          * @param pec a GenericErrorCaclculator object 
  198.          */  
  199.         Ransac(GenericModelCalculator<M, S>* pmc, int modelSampleSize, GenericErrorCaclculator<M, S>* pec) {  
  200.                 this->pmc = pmc;  
  201.                 this->modelSampleSize = modelSampleSize;  
  202.                 this->pec = pec;  
  203.                 this->maxSampleCount = 500;  
  204.                 this->minInliersNum = 1000000;  
  205.   
  206.                 this->verbose = false;  
  207.         }  
  208.   
  209.         const GenericErrorCaclculator<M, S>* getErrorCalculator() const { return pec; }  
  210.   
  211.         virtual ~Ransac() {  
  212.                 delete pmc;  
  213.                 delete pec;  
  214.         }  
  215.   
  216.         void setMaxSampleCount(int n) {  
  217.                 this->maxSampleCount = n;  
  218.         }  
  219.   
  220.         void setMinInliersNum(int n) {  
  221.                 this->minInliersNum = n;  
  222.         }  
  223.   
  224.         virtual bool hasThreshold() const { return true; }  
  225.   
  226.         virtual void setThreshold(double threshold) {  
  227.                 pec->setThreshold(threshold);  
  228.         }  
  229.   
  230.         virtual double getThreshold() const {  
  231.                 return pec->getThreshold();  
  232.         }  
  233.   
  234. public:  
  235.         /** Given samples, compute a model that has most inliers. Assume the samples size is larger or equal than model sample size 
  236.          * inliers, outliers, inlierIndices and outlierIndices are stored 
  237.          *  
  238.          */  
  239.         M compute(const Samples& samples) {  
  240.                 clear();  
  241.   
  242.                 int pointsNumber = (int)samples.size();  
  243.   
  244.                 assert(pointsNumber >= modelSampleSize);  
  245.   
  246.                 long N = 100000;  
  247.                 int sampleCount = 0;  
  248.                 RANSAC_SamplesNumber ransac(modelSampleSize);  
  249.   
  250.                 M bestModel;  
  251.                 int maxInliersNumber = 0;  
  252.   
  253.                 bool stop = false;  
  254.                 while (sampleCount < N && sampleCount < maxSampleCount && !stop) {  
  255.   
  256.                         Samples nsamples;  
  257.                         randomlySampleN(samples, nsamples, modelSampleSize);  
  258.   
  259.                         M sampleModel = pmc->compute(nsamples);  
  260.                         if (maxInliersNumber == 0) bestModel = sampleModel;  //init bestModel  
  261.   
  262.                         int inliersNumber = pec->calcInliersNumber(sampleModel, samples);  
  263.                         if (verbose) std::cout<<"inliers number: "<<inliersNumber<<std::endl;  
  264.   
  265.                         if (inliersNumber > maxInliersNumber) {  
  266.                                 bestModel = sampleModel;  
  267.                                 maxInliersNumber = inliersNumber;  
  268.                                 N = ransac.calcN(inliersNumber, pointsNumber);  
  269.                                 if (maxInliersNumber > minInliersNum) stop = true;  
  270.                         }  
  271.   
  272.                         if (verbose) std::cout<<"N: "<<N<<std::endl;  
  273.   
  274.                         sampleCount ++;  
  275.                 }  
  276.   
  277.                 if (verbose) std::cout<<"sampleCount: "<<sampleCount<<std::endl;  
  278.   
  279.                 finalModel = computeUntilConverge(bestModel, maxInliersNumber, samples);  
  280.                   
  281.                 pec->classify(finalModel, samples, inlierIndices, outlierIndices, inliers, outliers);  
  282.   
  283.                 inliersRate = (double)inliers.size() / samples.size();  
  284.   
  285.                 return finalModel;  
  286.         }  
  287.   
  288.         const Samples& getInliers() const { return inliers; }  
  289.         const Samples& getOutliers() const { return outliers; }  
  290.   
  291.         const std::vector<int>& getInlierIndices() const { return inlierIndices; }  
  292.         const std::vector<int>& getOutlierIndices() const { return outlierIndices; }  
  293.   
  294.         double getInliersAverageError() const {  
  295.                 return pec->computeAverage(finalModel, inliers);  
  296.         }  
  297.   
  298.         double getInliersRate() const {  
  299.                 return inliersRate;  
  300.         }  
  301.   
  302.         void setVerbose(bool v) {  
  303.                 verbose = v;  
  304.         }  
  305.   
  306. private:  
  307.         void randomlySampleN(const Samples& samples, Samples& nsamples, int sampleSize) {  
  308.                 std::vector<int> is = ranis((int)samples.size(), sampleSize);  
  309.                 for (int i = 0; i < sampleSize; i++) {  
  310.                         nsamples.push_back(samples[is[i]]);  
  311.                 }  
  312.         }  
  313.   
  314.         /** from initial model, iterate to find the best model. 
  315.          * 
  316.          */  
  317.         M computeUntilConverge(M initModel, int initInliersNum, const Samples& samples) {  
  318.                 if (verbose) {  
  319.                         std::cout<<"iterate until converge...."<<std::endl;  
  320.                         std::cout<<"init inliers number: "<<initInliersNum<<std::endl;  
  321.                 }  
  322.   
  323.                 M bestModel = initModel;  
  324.                 M newModel = initModel;  
  325.   
  326.                 int lastInliersNum = initInliersNum;  
  327.   
  328.                 Samples newInliers, newOutliers;  
  329.                 pec->classify(initModel, samples, newInliers, newOutliers);  
  330.                 double lastInlierAverageError = pec->computeAverage(initModel, newInliers);  
  331.   
  332.                 if (verbose) std::cout<<"init inlier average error: "<<lastInlierAverageError<<std::endl;  
  333.   
  334.                 while (true && (int)newInliers.size() >= modelSampleSize) {  
  335.   
  336.                         //update new model with new inliers, the new model does not necessarily have more inliers  
  337.                         newModel = pmc->compute(newInliers);  
  338.   
  339.                         pec->classify(newModel, samples, newInliers, newOutliers);  
  340.   
  341.                         int newInliersNum = (int)newInliers.size();  
  342.                         double newInlierAverageError = pec->computeAverage(newModel, newInliers);  
  343.   
  344.                         if (verbose) {  
  345.                                 std::cout<<"new inliers number: "<<newInliersNum<<std::endl;  
  346.                                 std::cout<<"new inlier average error: "<<newInlierAverageError<<std::endl;  
  347.                         }  
  348.                         if (newInliersNum < lastInliersNum) break;  
  349.                         if (newInliersNum == lastInliersNum && newInlierAverageError >= lastInlierAverageError) break;  
  350.   
  351.                         //update best model with the model has more inliers  
  352.                         bestModel = newModel;  
  353.   
  354.                         lastInliersNum = newInliersNum;  
  355.                         lastInlierAverageError = newInlierAverageError;  
  356.                 }  
  357.   
  358.                 return bestModel;  
  359.         }  
  360.   
  361.         void clear() {  
  362.                 inliers.clear();  
  363.                 outliers.clear();  
  364.                 inlierIndices.clear();  
  365.                 outlierIndices.clear();  
  366.         }  
  367.   
  368. private:  
  369.         GenericModelCalculator<M, S>* pmc;  
  370.         GenericErrorCaclculator<M, S>* pec;  
  371.         int modelSampleSize;  
  372.   
  373.         int maxSampleCount;  
  374.         int minInliersNum;  
  375.   
  376.         M finalModel;  
  377.   
  378.         Samples inliers;  
  379.         Samples outliers;  
  380.   
  381.         std::vector<int> inlierIndices;  
  382.         std::vector<int> outlierIndices;  
  383.   
  384.         double inliersRate;  
  385.   
  386. private:  
  387.         bool verbose;  
  388.   
  389. };  
  390.   
  391. }  
  392. #endif // FVISION_RANSAC_H_  

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多