分享

反向传播的全矩阵方法

 netouch 2024-04-26 发布于北京

之前在神经网络随机梯度下降计算梯度时,在反向传播时每个样本单独计算梯度,然后再求小批量数据的梯度平均值;而现在全矩阵方法是将整个小批量作为一个矩阵(一个样本作为一列)输入整体利用矩阵运算一次计算梯度平均值,用计算出的梯度平均值去更新权重和偏置。结果表明,全矩阵方法能够提升效率平均5倍左右,由开始的平均10秒到2秒
废话不多说,直接上代码:

# ⼩批量数据上的反向传播的全矩阵⽅法,并且最后更新权重
    def backprop_matrix(self, x, y, m, eta):
        '''
        ⼩批量数据上的反向传播的全矩阵⽅法
        :param x: 小批量数据的输入矩阵,一列代表一个样本
        :param y: 期望输出矩阵
        :param m: 数据的规模
        :param eta: 学习率
        :return:
        '''
        # 根据权重矩阵和偏置列向量的形状生成梯度矩阵
        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]
        # 第一步,设置输入激活值矩阵
        activation = x
        activations = [x]   # 保存各层的激活值矩阵
        zs = []   # 保存各层的带权输入矩阵
        # 第二步,前向传播,计算各层的带权输入和激活值
        for w, b in zip(self.weights, self.biases):
            z = np.dot(w, activation) + b
            activation = sigmoid(z)
            zs.append(z)
            activations.append(activation)
        # 第三步,计算输出层误差矩阵
        delta = cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
        # 计算输出层的偏置和权重的梯度
        nabla_b[-1] = np.array([np.mean(delta, axis=1)]).transpose()
        # self.biases[-1] = self.biases[-1] - eta * nabla_b
        nabla_w[-1] = np.dot(delta, activations[-2].transpose()) / m
        # self.weights[-1] = self.weights[-1] - eta * nabla_w
        # 第四步,反向传播误差,并且用误差计算梯度
        for l in range(2, self.num_layers):
            delta = np.dot(self.weights[-l+1].transpose(), delta) * sigmoid_prime(zs[-l])
            nabla_b[-l] = np.array([np.mean(delta, axis=1)]).transpose()
            nabla_w[-l] = np.dot(delta, activations[-l-1].transpose()) / m
        # 第五步,梯度下降更新参数
        for l in range(1, self.num_layers):
            self.biases[-l] = self.biases[-l] - eta * nabla_b[-l]
            self.weights[-l] = self.weights[-l] - eta * nabla_w[-l]

最后在具体操作的过程在有什么问题,欢迎大家一起交流讨论。
在下编程小白,如果有什么错误欢迎大家批评指正!
邮箱:1916728303@qq.com

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多