分享

梳理caffe代码blob(三)

 mscdj 2016-09-29

梳理caffe代码blob(三)

贯穿整个caffe的就是数据blob:

 

  1. #ifndef CAFFE_BLOB_HPP_  
  2. #define CAFFE_BLOB_HPP_  
  3.   
  4. #include <algorithm>  
  5. #include <string>  
  6. #include <vector>  
  7.   
  8. #include "caffe/common.hpp"  
  9. #include "caffe/proto/caffe.pb.h"  
  10. #include "caffe/syncedmem.hpp"  
  11. #include "caffe/util/math_functions.hpp"  
  12.   
  13. const int kMaxBlobAxes = INT_MAX;  
  14.   
  15. namespace caffe {  
  16.   
  17. /** 
  18.  * @brief A wrapper around SyncedMemory holders serving as the basic 
  19.  *        computational unit through which Layer%s, Net%s, and Solver%s 
  20.  *        interact. 
  21.  * 
  22.  * TODO(dox): more thorough description. 
  23.  */  
  24.   
  25.   
  26. template <typename Dtype>  
  27. class Blob {  
  28.  public:  
  29.   Blob()  
  30.        : data_(), diff_(), count_(0), capacity_(0) {}  
  31.   
  32.   /// @brief Deprecated; use <code>Blob(const vector<int>& shape)</code>.  
  33.   //explicit关键字的作用是禁止单参数构造函数的隐式转换  
  34.   explicit Blob(const int num, const int channels, const int height,  
  35.       const int width);  
  36.   explicit Blob(const vector<int>& shape);  
  37.   
  38.   /// @brief Deprecated; use <code>Reshape(const vector<int>& shape)</code>.  
  39. /* 
  40. Reshape函数将num,channels,height,width传递给vector shape_  
  41. */  
  42.   void Reshape(const int num, const int channels, const int height,  
  43.       const int width);  
  44.  /** 
  45.  *Blob作为一个最基础的类,其中构造函数开辟一个内存空间来存储数据,Reshape函数在Layer中的 
  46.  *reshape或者forward 操作中来adjust the dimensions of a top blob。同时在改变Blob大小时, 
  47.  *内存将会被重新分配如果内存大小不够了,并且额外的内存将不会被释放。对input的blob进行reshape, 
  48.  *如果立马调用Net::Backward是会出错的,因为reshape之后,要么Net::forward或者Net::Reshape就会 
  49.  *被调用来将新的input shape 传播到高层 
  50.  */  
  51.   //根据shape来初始化shape_和shape_data_,以及为data_ 和diff_ 分配空间。   
  52.   void Reshape(const vector<int>& shape);  
  53.   void Reshape(const BlobShape& shape);  
  54.   void ReshapeLike(const Blob& other);  
  55.   //iniline主要是将代码进行复制,扩充,会使代码总量上升,好处就是可以节省调用的开销,以string形式获取shape_  
  56.   inline string shape_string() const {  
  57.     ostringstream stream;  
  58.     for (int i = 0; i < shape_.size(); ++i) {  
  59.       stream << shape_[i] << " ";  
  60.     }  
  61.     stream << "(" << count_ << ")";  
  62.     return stream.str();  
  63.   }  
  64. //获取shape_  
  65.   inline const vector<int>& shape() const { return shape_; }  
  66.   /** 
  67.    * @brief Returns the dimension of the index-th axis (or the negative index-th 
  68.    *        axis from the end, if index is negative). 
  69.    * 
  70.    * @param index the axis index, which may be negative as it will be 
  71.    *        "canonicalized" using CanonicalAxisIndex. 
  72.    *        Dies on out of range index. 
  73.    */  
  74. //获取index维的大小  
  75.   inline int shape(int index) const {  
  76.     return shape_[CanonicalAxisIndex(index)];  
  77.   }  
  78. //获取维的个数  
  79.   inline int num_axes() const { return shape_.size(); }  
  80. //获取当前data的大小  
  81.   inline int count() const { return count_; }  
  82.   
  83.   /** 
  84.    * @brief Compute the volume of a slice; i.e., the product of dimensions 
  85.    *        among a range of axes. 
  86.    * 
  87.    * @param start_axis The first axis to include in the slice. 
  88.    * 
  89.    * @param end_axis The first axis to exclude from the slice. 
  90.    */  
  91. /*多个count()函数,主要还是为了统计Blob的容量(volume),或者是某一片(slice), 
  92. 从某个axis到具体某个axis的shape乘积。 
  93. */  
  94. //获取某几维数据的大小  
  95.   inline int count(int start_axis, int end_axis) const {  
  96.     CHECK_LE(start_axis, end_axis);  
  97.     CHECK_GE(start_axis, 0);  
  98.     CHECK_GE(end_axis, 0);  
  99.     CHECK_LE(start_axis, num_axes());  
  100.     CHECK_LE(end_axis, num_axes());  
  101.     int count = 1;  
  102.     for (int i = start_axis; i < end_axis; ++i) {  
  103.       count *= shape(i);  
  104.     }  
  105.     return count;  
  106.   }  
  107.   /** 
  108.    * @brief Compute the volume of a slice spanning from a particular first 
  109.    *        axis to the final axis. 
  110.    * 
  111.    * @param start_axis The first axis to include in the slice. 
  112.    */  
  113. //获取某一维到结束数据的大小  
  114.   inline int count(int start_axis) const {  
  115.     return count(start_axis, num_axes());  
  116.   }  
  117.   
  118.   /** 
  119.    * @brief Returns the 'canonical' version of a (usually) user-specified axis, 
  120.    *        allowing for negative indexing (e.g., -1 for the last axis). 
  121.    * 
  122.    * @param index the axis index. 
  123.    *        If 0 <= index < num_axes(), return index. 
  124.    *        If -num_axes <= index <= -1, return (num_axes() - (-index)), 
  125.    *        e.g., the last axis index (num_axes() - 1) if index == -1, 
  126.    *        the second to last if index == -2, etc. 
  127.    *        Dies on out of range index. 
  128.    */  
  129.   //Blob的Index是可以从负坐标开始读的,标准化索引,主要是对参数索引进行标准化,以满足要求  
  130.   inline int CanonicalAxisIndex(int axis_index) const {  
  131.     CHECK_GE(axis_index, -num_axes())  
  132.         << "axis " << axis_index << " out of range for " << num_axes()  
  133.         << "-D Blob with shape " << shape_string();  
  134.     CHECK_LT(axis_index, num_axes())  
  135.         << "axis " << axis_index << " out of range for " << num_axes()  
  136.         << "-D Blob with shape " << shape_string();  
  137.     if (axis_index < 0) {  
  138.       return axis_index + num_axes();  
  139.     }  
  140.     return axis_index;  
  141.   }  
  142.   //Blob中的4个基本变量num,channel,height,width可以直接通过shape(0),shape(1),shape(2),shape(3)来访问  
  143.   /// @brief Deprecated legacy shape accessor num: use shape(0) instead.  
  144.   inline int num() const { return LegacyShape(0); }  
  145.   /// @brief Deprecated legacy shape accessor channels: use shape(1) instead.  
  146.   inline int channels() const { return LegacyShape(1); }  
  147.   /// @brief Deprecated legacy shape accessor height: use shape(2) instead.  
  148.   inline int height() const { return LegacyShape(2); }  
  149.   /// @brief Deprecated legacy shape accessor width: use shape(3) instead.  
  150.   inline int width() const { return LegacyShape(3); }  
  151. //data_维数不大于4时才能使用,功能同shape()类似。  
  152.   inline int LegacyShape(int index) const {  
  153.     CHECK_LE(num_axes(), 4)  
  154.         << "Cannot use legacy accessors on Blobs with > 4 axes.";  
  155.     CHECK_LT(index, 4);  
  156.     CHECK_GE(index, -4);  
  157.     if (index >= num_axes() || index < -num_axes()) {  
  158.       // Axis is out of range, but still in [0, 3] (or [-4, -1] for reverse  
  159.       // indexing) -- this special case simulates the one-padding used to fill  
  160.       // extraneous axes of legacy blobs.  
  161.       return 1;  
  162.     }  
  163.     return shape(index);  
  164.   }  
  165.   //计算offset,offset计算的方式也支持两种方式,一种直接指定n,c,h,w或者放到一个vector中进行计算,  
  166.   //偏差是根据对应的n,c,h,w,返回的offset是((n*channels()+c)*height()+h)*width()+w  
  167.   inline int offset(const int n, const int c = 0, const int h = 0,  
  168.       const int w = 0) const {  
  169.     CHECK_GE(n, 0);  
  170.     CHECK_LE(n, num());  
  171.     CHECK_GE(channels(), 0);  
  172.     CHECK_LE(c, channels());  
  173.     CHECK_GE(height(), 0);  
  174.     CHECK_LE(h, height());  
  175.     CHECK_GE(width(), 0);  
  176.     CHECK_LE(w, width());  
  177.     return ((n * channels() + c) * height() + h) * width() + w;  
  178.   }  
  179.   
  180.   inline int offset(const vector<int>& indices) const {  
  181.     CHECK_LE(indices.size(), num_axes());  
  182.     int offset = 0;  
  183.     for (int i = 0; i < num_axes(); ++i) {  
  184.       offset *= shape(i);  
  185.       if (indices.size() > i) {  
  186.         CHECK_GE(indices[i], 0);  
  187.         CHECK_LT(indices[i], shape(i));  
  188.         offset += indices[i];  
  189.       }  
  190.     }  
  191.     return offset;  
  192.   }  
  193.   /** 
  194.    * @brief Copy from a source Blob. 
  195.    * 
  196.    * @param source the Blob to copy from 
  197.    * @param copy_diff if false, copy the data; if true, copy the diff 
  198.    * @param reshape if false, require this Blob to be pre-shaped to the shape 
  199.    *        of other (and die otherwise); if true, Reshape this Blob to other's 
  200.    *        shape if necessary 
  201.    */  
  202.   //一个blob中copy数据 ,通过开关控制是否copy_diff,如果是False则copy data。reshape控制是否需要reshape  
  203.   void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,  
  204.       bool reshape = false);  
  205. /*这一部分函数主要通过给定的位置访问数据,根据位置计算与数据起始 
  206.   的偏差offset,在通过cpu_data*指针获得地址 
  207. */  
  208. //获取某位置的data_数据  
  209.   inline Dtype data_at(const int n, const int c, const int h,  
  210.       const int w) const {  
  211.     return cpu_data()[offset(n, c, h, w)];  
  212.   }  
  213. //获取某位置的diff_数据  
  214.   inline Dtype diff_at(const int n, const int c, const int h,  
  215.       const int w) const {  
  216.     return cpu_diff()[offset(n, c, h, w)];  
  217.   }  
  218.   
  219.   inline Dtype data_at(const vector<int>& index) const {  
  220.     return cpu_data()[offset(index)];  
  221.   }  
  222.   
  223.   inline Dtype diff_at(const vector<int>& index) const {  
  224.     return cpu_diff()[offset(index)];  
  225.   }  
  226. //获取data_  
  227.   inline const shared_ptr<SyncedMemory>& data() const {  
  228.     CHECK(data_);  
  229.     return data_;  
  230.   }  
  231. //获取diff_  
  232.   inline const shared_ptr<SyncedMemory>& diff() const {  
  233.     CHECK(diff_);  
  234.     return diff_;  
  235.   }  
  236.   //这里有data和diff两类数据,而这个diff就是我们所熟知的偏差,前者主要存储  
  237.   //前向传递的数据,而后者存储的是反向传播中的梯度  
  238.   const Dtype* cpu_data() const;//获取data_ cpu指针  
  239.   void set_cpu_data(Dtype* data);//设置data_的cpu指针,只是修改了指针  
  240.   const Dtype* gpu_data() const;//获取data_的gpu指针  
  241.   const Dtype* cpu_diff() const;//获取diff_的cpu指针  
  242.   const Dtype* gpu_diff() const;//获取diff_的gpu指针  
  243.   Dtype* mutable_cpu_data();//见SyncedMemory的mutable_cpu_data();  
  244.   Dtype* mutable_gpu_data();//见SyncedMemory的mutable_gpu_data();  
  245.   Dtype* mutable_cpu_diff();//见SyncedMemory的mutable_cpu_data();  
  246.   Dtype* mutable_gpu_diff();//见SyncedMemory的mutable_gpu_data();  
  247.   //更新data_的数据,减去diff_的数据  
  248.   void Update();  
  249. /* 
  250. 其中用到math_functions.hpp中的函数caffe_axpy(),该函数封装了cblas_saxpy,实现的是Y=alpha*X+Y。 
  251. 由此,知该函数的功能是data_=(data_-diff_)。另外,该函数只实现了对double和float型数据, 
  252. 对于unsigned int和int由于该函数主要是在Net中被调用,只有Blob<float>和Blob<double>型式, 
  253. 因此没有定义unsigned int和int。 
  254. */  
  255.   void FromProto(const BlobProto& proto, bool reshape = true);  
  256. /* 
  257. 由BlobProto对Blob进行赋值操作。reshape代表是否允许修改shape_的大小。 
  258. 需要注意的是再这里有double和float两种类型的数据 ,在代码中可以看到具体的体现 
  259. */  
  260.   void ToProto(BlobProto* proto, bool write_diff = false) const;  
  261.   
  262.   /// @brief Compute the sum of absolute values (L1 norm) of the data.  
  263. /* 
  264. 功能:计算L1范数 
  265. 说明:其中用到了math_function.hpp中的函数caffe_cpu_asum()和caffe_gpu_asum,实现的功能是对向量X求其每个元素绝对值的和,不同的是X分别在cpu和gpu中。 
  266. */  
  267.   Dtype asum_data() const;  
  268.   /// @brief Compute the sum of absolute values (L1 norm) of the diff.  
  269.   Dtype asum_diff() const;  
  270.   /// @brief Compute the sum of squares (L2 norm squared) of the data.  
  271. /* 
  272. 功能:计算L2范数。 
  273. 说明:用到了math_function.hpp中的caffe_cpu_dot(),caffe_cpu_strided_dot(),caffe_gpu_dot(), caffe_gpu_strided_dot()。具体就是就向量X的平方和。 
  274. */  
  275.   Dtype sumsq_data() const;  
  276.   /// @brief Compute the sum of squares (L2 norm squared) of the diff.  
  277.   Dtype sumsq_diff() const;  
  278.   
  279.   /// @brief Scale the blob data by a constant factor.  
  280. /* 
  281. 功能:正规化data_。 
  282. 说明:用到math_function.hpp中的caffe_scal()和caffe_gpu_scal()函数,就是对向量X乘上一个因子。 
  283. */  
  284.   void scale_data(Dtype scale_factor);  
  285.   /// @brief Scale the blob diff by a constant factor.  
  286.   void scale_diff(Dtype scale_factor);  
  287.   
  288.   /** 
  289.    * @brief Set the data_ shared_ptr to point to the SyncedMemory holding the 
  290.    *        data_ of Blob other -- useful in Layer%s which simply perform a copy 
  291.    *        in their Forward pass. 
  292.    * 
  293.    * This deallocates the SyncedMemory holding this Blob's data_, as 
  294.    * shared_ptr calls its destructor when reset with the "=" operator. 
  295.    */  
  296.   void ShareData(const Blob& other);//本Blob共享other的data_  
  297.   /** 
  298.    * @brief Set the diff_ shared_ptr to point to the SyncedMemory holding the 
  299.    *        diff_ of Blob other -- useful in Layer%s which simply perform a copy 
  300.    *        in their Forward pass. 
  301.    * 
  302.    * This deallocates the SyncedMemory holding this Blob's diff_, as 
  303.    * shared_ptr calls its destructor when reset with the "=" operator. 
  304.    */  
  305.   void ShareDiff(const Blob& other);//本Blob共享other的diff_  
  306.   
  307.   bool ShapeEquals(const BlobProto& other);//判断other与本Blob形状是否相同。  
  308.   
  309.  protected:  
  310. //data_指针,指针类型是shared_ptr,属于boost库的一个智能指针,这一部分主要用来申请内存存储data,data主要是正向传播的时候用的  
  311.   shared_ptr<SyncedMemory> data_;  
  312. //diff_主要用来存储偏差,update data  
  313.   shared_ptr<SyncedMemory> diff_;  
  314. //shape_存储Blob的形状  
  315.   vector<int> shape_;  
  316. //count_表示Blob中的元素个数,也就是个数*通道数*高度*宽度  
  317.   int count_;  
  318. //capacity表示当前的元素个数,因为Blob可能会reshape  
  319.   int capacity_;  
  320.   
  321.   DISABLE_COPY_AND_ASSIGN(Blob);  
  322. };  // class Blob  
  323.   
  324. }  // namespace caffe  
  325.   
  326. #endif  // CAFFE_BLOB_HPP_  

顺便将实现部分也贴出来,方便对照:

  1. #include <climits>  
  2. #include <vector>  
  3.   
  4. #include "caffe/blob.hpp"  
  5. #include "caffe/common.hpp"  
  6. #include "caffe/syncedmem.hpp"  
  7. #include "caffe/util/math_functions.hpp"  
  8.   
  9. namespace caffe {  
  10.   
  11. template <typename Dtype>  
  12. //该函数将num,channels,height,width传递给vector shape_   
  13. void Blob<Dtype>::Reshape(const int num, const int channels, const int height,  
  14.     const int width) {  
  15.   vector<int> shape(4);  
  16.   shape[0] = num;  
  17.   shape[1] = channels;  
  18.   shape[2] = height;  
  19.   shape[3] = width;  
  20.   Reshape(shape);  
  21. }  
  22.   
  23. template <typename Dtype>  
  24. void Blob<Dtype>::Reshape(const vector<int>& shape) {  
  25.   CHECK_LE(shape.size(), kMaxBlobAxes);  
  26.   count_ = 1;  
  27.   shape_.resize(shape.size());//重新定义vector shape_ 的size  
  28.   for (int i = 0; i < shape.size(); ++i) {  
  29.     CHECK_GE(shape[i], 0);//确保shape 每个元素为正数  
  30.     CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";  
  31.     count_ *= shape[i];  
  32.     shape_[i] = shape[i];  
  33.   }  
  34.   //由于count_超过了当前capacity_ 因此需要重新分配内存空间  
  35.   if (count_ > capacity_) {  
  36.     capacity_ = count_;  
  37.     data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));  
  38.     diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));  
  39.   }  
  40. }  
  41.   
  42. template <typename Dtype>// BlobShape 在caffe.proto 中定义  
  43. void Blob<Dtype>::Reshape(const BlobShape& shape) {  
  44.   CHECK_LE(shape.dim_size(), kMaxBlobAxes);  
  45.   vector<int> shape_vec(shape.dim_size());  
  46.   for (int i = 0; i < shape.dim_size(); ++i) {  
  47.     shape_vec[i] = shape.dim(i);//dim 包含num,channels,height, width  
  48.   }  
  49.   Reshape(shape_vec);//用protobuf传递来dim 对shape_ 进行reshape  
  50. }  
  51. //用已知的Blob的shape来对shape_ 进行reshape  
  52. template <typename Dtype>  
  53. void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) {  
  54.   Reshape(other.shape());  
  55. }  
  56. //用num,channels,height, width 初始化  
  57. template <typename Dtype>  
  58. Blob<Dtype>::Blob(const int num, const int channels, const int height,  
  59.     const int width)  
  60.   // capacity_ must be initialized before calling Reshape  
  61.   : capacity_(0) {  
  62.   Reshape(num, channels, height, width);  
  63. }  
  64. //用shape 初始化  
  65. template <typename Dtype>  
  66. Blob<Dtype>::Blob(const vector<int>& shape)  
  67.   // capacity_ must be initialized before calling Reshape  
  68.   : capacity_(0) {  
  69.   Reshape(shape);  
  70. }  
  71. //返回cpu 中的数据  
  72. template <typename Dtype>  
  73. const Dtype* Blob<Dtype>::cpu_data() const {  
  74.   CHECK(data_);  
  75.   return (const Dtype*)data_->cpu_data();  
  76. }  
  77. // 清空cpu 数据  
  78. template <typename Dtype>  
  79. void Blob<Dtype>::set_cpu_data(Dtype* data) {  
  80.   CHECK(data);  
  81.   data_->set_cpu_data(data);  
  82. }  
  83. //返回gpu 中的数据  
  84. template <typename Dtype>  
  85. const Dtype* Blob<Dtype>::gpu_data() const {  
  86.   CHECK(data_);  
  87.   return (const Dtype*)data_->gpu_data();  
  88. }  
  89. //反向传播导数diff_ 操作函数,返回cpu 中的数据  
  90. template <typename Dtype>  
  91. const Dtype* Blob<Dtype>::cpu_diff() const {  
  92.   CHECK(diff_);  
  93.   return (const Dtype*)diff_->cpu_data();  
  94. }  
  95. //返回gpu 中的数据  
  96. template <typename Dtype>  
  97. const Dtype* Blob<Dtype>::gpu_diff() const {  
  98.   CHECK(diff_);  
  99.   return (const Dtype*)diff_->gpu_data();  
  100. }  
  101.   
  102. template <typename Dtype>  
  103. Dtype* Blob<Dtype>::mutable_cpu_data() {  
  104.   CHECK(data_);  
  105.   return static_cast<Dtype*>(data_->mutable_cpu_data());  
  106. }  
  107.   
  108. template <typename Dtype>  
  109. Dtype* Blob<Dtype>::mutable_gpu_data() {  
  110.   CHECK(data_);  
  111.   return static_cast<Dtype*>(data_->mutable_gpu_data());  
  112. }  
  113.   
  114. template <typename Dtype>  
  115. Dtype* Blob<Dtype>::mutable_cpu_diff() {  
  116.   CHECK(diff_);  
  117.   return static_cast<Dtype*>(diff_->mutable_cpu_data());  
  118. }  
  119.   
  120. template <typename Dtype>  
  121. Dtype* Blob<Dtype>::mutable_gpu_diff() {  
  122.   CHECK(diff_);  
  123.   return static_cast<Dtype*>(diff_->mutable_gpu_data());  
  124. }  
  125. //当前的blob 的data_ 指向已知blob的数据  
  126. template <typename Dtype>  
  127. void Blob<Dtype>::ShareData(const Blob& other) {  
  128.   CHECK_EQ(count_, other.count());  
  129.   data_ = other.data();  
  130. }  
  131. //当前的blob 的diff_ 指向已知blob的反向传播导数  
  132. template <typename Dtype>  
  133. void Blob<Dtype>::ShareDiff(const Blob& other) {  
  134.   CHECK_EQ(count_, other.count());  
  135.   diff_ = other.diff();  
  136. }  
  137.   
  138. // The "update" method is used for parameter blobs in a Net, which are stored  
  139. // as Blob<float> or Blob<double> -- hence we do not define it for  
  140. // Blob<int> or Blob<unsigned int>.  
  141. template <> void Blob<unsigned int>::Update() { NOT_IMPLEMENTED; }  
  142. template <> void Blob<int>::Update() { NOT_IMPLEMENTED; }  
  143. //Updata函数用于参数blob的更新(weight,bias 等减去对应的导数)  
  144. template <typename Dtype>  
  145. void Blob<Dtype>::Update() {  
  146.   // We will perform update based on where the data is located.  
  147.   switch (data_->head()) {  
  148.   case SyncedMemory::HEAD_AT_CPU://数据在cpu上,则在cpu上进行计算  
  149.     // perform computation on CPU  
  150.     caffe_axpy<Dtype>(count_, Dtype(-1),  
  151.         static_cast<const Dtype*>(diff_->cpu_data()),  
  152.         static_cast<Dtype*>(data_->mutable_cpu_data()));  
  153.     break;  
  154.   case SyncedMemory::HEAD_AT_GPU:  
  155.   case SyncedMemory::SYNCED:  
  156. #ifndef CPU_ONLY//如果没有定义CPU_ONLY,且数据在gpu上,则在gpu上进行计算  
  157.     // perform computation on GPU  
  158.     caffe_gpu_axpy<Dtype>(count_, Dtype(-1),  
  159.         static_cast<const Dtype*>(diff_->gpu_data()),  
  160.         static_cast<Dtype*>(data_->mutable_gpu_data()));  
  161. #else  
  162.     NO_GPU;  
  163. #endif  
  164.     break;  
  165.   default:  
  166.     LOG(FATAL) << "Syncedmem not initialized.";  
  167.   }  
  168. }  
  169.   
  170. template <> unsigned int Blob<unsigned int>::asum_data() const {  
  171.   NOT_IMPLEMENTED;  
  172.   return 0;  
  173. }  
  174.   
  175. template <> int Blob<int>::asum_data() const {  
  176.   NOT_IMPLEMENTED;  
  177.   return 0;  
  178. }  
  179. //返回data_ 中所有 element 的绝对值之和  
  180. template <typename Dtype>  
  181. Dtype Blob<Dtype>::asum_data() const {  
  182.   if (!data_) { return 0; }  
  183.   switch (data_->head()) {  
  184.   case SyncedMemory::HEAD_AT_CPU:  
  185.     return caffe_cpu_asum(count_, cpu_data());  
  186.   case SyncedMemory::HEAD_AT_GPU:  
  187.   case SyncedMemory::SYNCED:  
  188. #ifndef CPU_ONLY  
  189.   {  
  190.     Dtype asum;  
  191.     caffe_gpu_asum(count_, gpu_data(), &asum);  
  192.     return asum;  
  193.   }  
  194. #else  
  195.     NO_GPU;  
  196. #endif  
  197.   case SyncedMemory::UNINITIALIZED:  
  198.     return 0;  
  199.   default:  
  200.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
  201.   }  
  202.   return 0;  
  203. }  
  204.   
  205. template <> unsigned int Blob<unsigned int>::asum_diff() const {  
  206.   NOT_IMPLEMENTED;  
  207.   return 0;  
  208. }  
  209.   
  210. template <> int Blob<int>::asum_diff() const {  
  211.   NOT_IMPLEMENTED;  
  212.   return 0;  
  213. }  
  214. //返回diff_ 中所有 element 的绝对值之和  
  215. template <typename Dtype>  
  216. Dtype Blob<Dtype>::asum_diff() const {  
  217.   if (!diff_) { return 0; }  
  218.   switch (diff_->head()) {  
  219.   case SyncedMemory::HEAD_AT_CPU:  
  220.     return caffe_cpu_asum(count_, cpu_diff());  
  221.   case SyncedMemory::HEAD_AT_GPU:  
  222.   case SyncedMemory::SYNCED:  
  223. #ifndef CPU_ONLY  
  224.   {  
  225.     Dtype asum;  
  226.     caffe_gpu_asum(count_, gpu_diff(), &asum);  
  227.     return asum;  
  228.   }  
  229. #else  
  230.     NO_GPU;  
  231. #endif  
  232.   case SyncedMemory::UNINITIALIZED:  
  233.     return 0;  
  234.   default:  
  235.     LOG(FATAL) << "Unknown SyncedMemory head state: " << diff_->head();  
  236.   }  
  237.   return 0;  
  238. }  
  239.   
  240. template <> unsigned int Blob<unsigned int>::sumsq_data() const {  
  241.   NOT_IMPLEMENTED;  
  242.   return 0;  
  243. }  
  244.   
  245. template <> int Blob<int>::sumsq_data() const {  
  246.   NOT_IMPLEMENTED;  
  247.   return 0;  
  248. }  
  249. //返回 data_ 中所有 element 的平方和  
  250. template <typename Dtype>  
  251. Dtype Blob<Dtype>::sumsq_data() const {  
  252.   Dtype sumsq;  
  253.   const Dtype* data;  
  254.   if (!data_) { return 0; }  
  255.   switch (data_->head()) {  
  256.   case SyncedMemory::HEAD_AT_CPU:  
  257.     data = cpu_data();  
  258.     sumsq = caffe_cpu_dot(count_, data, data);  
  259.     break;  
  260.   case SyncedMemory::HEAD_AT_GPU:  
  261.   case SyncedMemory::SYNCED:  
  262. #ifndef CPU_ONLY  
  263.     data = gpu_data();  
  264.     caffe_gpu_dot(count_, data, data, &sumsq);  
  265. #else  
  266.     NO_GPU;  
  267. #endif  
  268.     break;  
  269.   case SyncedMemory::UNINITIALIZED:  
  270.     return 0;  
  271.   default:  
  272.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
  273.   }  
  274.   return sumsq;  
  275. }  
  276.   
  277. template <> unsigned int Blob<unsigned int>::sumsq_diff() const {  
  278.   NOT_IMPLEMENTED;  
  279.   return 0;  
  280. }  
  281.   
  282. template <> int Blob<int>::sumsq_diff() const {  
  283.   NOT_IMPLEMENTED;  
  284.   return 0;  
  285. }  
  286. //返回 diff_ 中所有 element 的平方和  
  287. template <typename Dtype>  
  288. Dtype Blob<Dtype>::sumsq_diff() const {  
  289.   Dtype sumsq;  
  290.   const Dtype* diff;  
  291.   if (!diff_) { return 0; }  
  292.   switch (diff_->head()) {  
  293.   case SyncedMemory::HEAD_AT_CPU:  
  294.     diff = cpu_diff();  
  295.     sumsq = caffe_cpu_dot(count_, diff, diff);  
  296.     break;  
  297.   case SyncedMemory::HEAD_AT_GPU:  
  298.   case SyncedMemory::SYNCED:  
  299. #ifndef CPU_ONLY  
  300.     diff = gpu_diff();  
  301.     caffe_gpu_dot(count_, diff, diff, &sumsq);  
  302.     break;  
  303. #else  
  304.     NO_GPU;  
  305. #endif  
  306.   case SyncedMemory::UNINITIALIZED:  
  307.     return 0;  
  308.   default:  
  309.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
  310.   }  
  311.   return sumsq;  
  312. }  
  313.   
  314. template <> void Blob<unsigned int>::scale_data(unsigned int scale_factor) {  
  315.   NOT_IMPLEMENTED;  
  316. }  
  317.   
  318. template <> void Blob<int>::scale_data(int scale_factor) {  
  319.   NOT_IMPLEMENTED;  
  320. }  
  321. // 给data乘以scale_factor  
  322. template <typename Dtype>  
  323. void Blob<Dtype>::scale_data(Dtype scale_factor) {  
  324.   Dtype* data;  
  325.   if (!data_) { return; }  
  326.   switch (data_->head()) {  
  327.   case SyncedMemory::HEAD_AT_CPU:  
  328.     data = mutable_cpu_data();  
  329.     caffe_scal(count_, scale_factor, data);  
  330.     return;  
  331.   case SyncedMemory::HEAD_AT_GPU:  
  332.   case SyncedMemory::SYNCED:  
  333. #ifndef CPU_ONLY  
  334.     data = mutable_gpu_data();  
  335.     caffe_gpu_scal(count_, scale_factor, data);  
  336.     return;  
  337. #else  
  338.     NO_GPU;  
  339. #endif  
  340.   case SyncedMemory::UNINITIALIZED:  
  341.     return;  
  342.   default:  
  343.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
  344.   }  
  345. }  
  346.   
  347. template <> void Blob<unsigned int>::scale_diff(unsigned int scale_factor) {  
  348.   NOT_IMPLEMENTED;  
  349. }  
  350.   
  351. template <> void Blob<int>::scale_diff(int scale_factor) {  
  352.   NOT_IMPLEMENTED;  
  353. }  
  354. // 给diff乘以scale_factor  
  355. template <typename Dtype>  
  356. void Blob<Dtype>::scale_diff(Dtype scale_factor) {  
  357.   Dtype* diff;  
  358.   if (!diff_) { return; }  
  359.   switch (diff_->head()) {  
  360.   case SyncedMemory::HEAD_AT_CPU:  
  361.     diff = mutable_cpu_diff();  
  362.     caffe_scal(count_, scale_factor, diff);  
  363.     return;  
  364.   case SyncedMemory::HEAD_AT_GPU:  
  365.   case SyncedMemory::SYNCED:  
  366. #ifndef CPU_ONLY  
  367.     diff = mutable_gpu_diff();  
  368.     caffe_gpu_scal(count_, scale_factor, diff);  
  369.     return;  
  370. #else  
  371.     NO_GPU;  
  372. #endif  
  373.   case SyncedMemory::UNINITIALIZED:  
  374.     return;  
  375.   default:  
  376.     LOG(FATAL) << "Unknown SyncedMemory head state: " << diff_->head();  
  377.   }  
  378. }  
  379. //BlobProto 是定义在caffe.proto 中的一个message,其字段有 data,diff,shape,num,channels,height,width  
  380. template <typename Dtype>  
  381. bool Blob<Dtype>::ShapeEquals(const BlobProto& other) {  
  382.   if (other.has_num() || other.has_channels() ||  
  383.       other.has_height() || other.has_width()) {  
  384.     // Using deprecated 4D Blob dimensions --  
  385.     // shape is (num, channels, height, width).  
  386.     // Note: we do not use the normal Blob::num(), Blob::channels(), etc.  
  387.     // methods as these index from the beginning of the blob shape, where legacy  
  388.     // parameter blobs were indexed from the end of the blob shape (e.g., bias  
  389.     // Blob shape (1 x 1 x 1 x N), IP layer weight Blob shape (1 x 1 x M x N)).  
  390.     return shape_.size() <= 4 &&  
  391.            LegacyShape(-4) == other.num() &&  
  392.            LegacyShape(-3) == other.channels() &&  
  393.            LegacyShape(-2) == other.height() &&  
  394.            LegacyShape(-1) == other.width();  
  395.   }  
  396.   vector<int> other_shape(other.shape().dim_size());  
  397.   for (int i = 0; i < other.shape().dim_size(); ++i) {  
  398.     other_shape[i] = other.shape().dim(i);  
  399.   }  
  400.   return shape_ == other_shape;  
  401. }//检查当前的blob和已知的 other 的 shape 是否相同,相同返回true  
  402.   
  403. template <typename Dtype>  
  404. void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {  
  405.   if (source.count() != count_ || source.shape() != shape_) {  
  406.     if (reshape) {  
  407.       ReshapeLike(source);  
  408.     } else {  
  409.       LOG(FATAL) << "Trying to copy blobs of different sizes.";  
  410.     }  
  411.   }  
  412.   switch (Caffe::mode()) {  
  413.   case Caffe::GPU:  
  414.     if (copy_diff) {  
  415.       caffe_copy(count_, source.gpu_diff(),  
  416.           static_cast<Dtype*>(diff_->mutable_gpu_data()));  
  417.     } else {  
  418.       caffe_copy(count_, source.gpu_data(),  
  419.           static_cast<Dtype*>(data_->mutable_gpu_data()));  
  420.     }  
  421.     break;  
  422.   case Caffe::CPU:  
  423.     if (copy_diff) {  
  424.       caffe_copy(count_, source.cpu_diff(),  
  425.           static_cast<Dtype*>(diff_->mutable_cpu_data()));  
  426.     } else {  
  427.       caffe_copy(count_, source.cpu_data(),  
  428.           static_cast<Dtype*>(data_->mutable_cpu_data()));  
  429.     }  
  430.     break;  
  431.   default:  
  432.     LOG(FATAL) << "Unknown caffe mode.";  
  433.   }  
  434. }//从source 拷贝数据,copy_diff控制是拷贝diff还是data  
  435.   
  436. template <typename Dtype>  
  437. void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {  
  438.   if (reshape) {  
  439.     vector<int> shape;  
  440.     if (proto.has_num() || proto.has_channels() ||  
  441.         proto.has_height() || proto.has_width()) {  
  442.       // Using deprecated 4D Blob dimensions --  
  443.       // shape is (num, channels, height, width).  
  444.       shape.resize(4);  
  445.       shape[0] = proto.num();  
  446.       shape[1] = proto.channels();  
  447.       shape[2] = proto.height();  
  448.       shape[3] = proto.width();  
  449.     } else {  
  450.       shape.resize(proto.shape().dim_size());  
  451.       for (int i = 0; i < proto.shape().dim_size(); ++i) {  
  452.         shape[i] = proto.shape().dim(i);  
  453.       }  
  454.     }  
  455.     Reshape(shape);  
  456.   } else {//如果不做reshape要求当前的blob的shape和proto传入的shape相同  
  457.     CHECK(ShapeEquals(proto)) << "shape mismatch (reshape not set)";  
  458.   }  
  459.   // copy data  
  460.   Dtype* data_vec = mutable_cpu_data();  
  461.   for (int i = 0; i < count_; ++i) {  
  462.     data_vec[i] = proto.data(i);  
  463.   }//将proto传入的data拷贝到cpu数据  
  464.   if (proto.diff_size() > 0) {  
  465.     Dtype* diff_vec = mutable_cpu_diff();  
  466.     for (int i = 0; i < count_; ++i) {  
  467.       diff_vec[i] = proto.diff(i);  
  468.     }//将proto传入的diff 拷贝到cpu数据  
  469.   }  
  470. }  
  471.   
  472. template <typename Dtype>  
  473. void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {  
  474.   proto->clear_shape();  
  475.   for (int i = 0; i < shape_.size(); ++i) {  
  476.     proto->mutable_shape()->add_dim(shape_[i]);  
  477.   }  
  478.   proto->clear_data();  
  479.   proto->clear_diff();  
  480.   const Dtype* data_vec = cpu_data();  
  481.   for (int i = 0; i < count_; ++i) {  
  482.     proto->add_data(data_vec[i]);//将data写入proto  
  483.   }  
  484.   if (write_diff) {  
  485.     const Dtype* diff_vec = cpu_diff();  
  486.     for (int i = 0; i < count_; ++i) {  
  487.       proto->add_diff(diff_vec[i]);//将diff写入proto  
  488.     }  
  489.   }  
  490. }  
  491.   
  492. INSTANTIATE_CLASS(Blob);  
  493. template class Blob<int>;  
  494. template class Blob<unsigned int>;  
  495.   
  496. }  // namespace caffe  

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多