深度学习之矩阵形式的链式法则推导
深度学习之矩阵形式的链式法则推导
对于深度学习的基础“梯度下降”和“自动微分”的数学原理网上讲解的博客有很多了,但是目前没看到有讲关于矩阵形式的链式法则的内容,所以写了这篇笔记,供自己学习和复习。
1. 复习基础
我印象中本科生学习的传统微积分中,当时学的是只有标量多元函数才能求梯度,本科阶段也只介绍了标量链式法则。为了尽可能简洁明了地进行推导,首先复习几个简单的概念:
几个基本概念
- 方向导数:是标量多元函数沿指定方向的变化率,它在原函数的特定某点上是一个数。根据定义,方向导数是原函数沿着这一方向,自变量增加一个无穷小量后的微小变化量,除以这个无穷小量,再取极限得到的一个数。反映了原函数在这个点的特定方向上的函数值变化率。
- 偏导数:一个多元标量函数对每个自变量各有一个偏导数。其本质是标量多元函数沿着每个元的坐标轴正方向的方向导数。
- 梯度:是一个向量,向量的每个维度为原函数对该维度的偏导数。同时,其也是原函数在任一点沿所有方向的最大方向导数。它是原函数增长最快的方向。
我本科数学阶段老师上课所讲的多元分析学,研究的函数基本上都在标量多元函数范畴内。
根据国际惯例,本文把标量记作小写字母 x x x , 向量记作粗体小写字母 x \textbf{x} x或者带有箭头上标的小写字母 x ? \vec{x} x ,矩阵记作大写字母 X X X.
举例
下面用一个简单的二元标量函数 y = x 1 2 + x 2 2 y = x_{1}^{2} + x_{2}^{2} y=x12?+x22? 为例简要介绍以上基本概念:
-
方向导数
? 求上述二元标量函数 y = f ( x ? ) y=f(\vec{x}) y=f(x)在点 ( 1 , 1 ) (1,1) (1,1)沿方向 ( ? 1 , ? 1 ) (-1,-1) (?1,?1)的方向导数:
? 根据定义:
方向导数 = lim ? t → 0 f ( x + t ) ? f ( x ) t 其中 t 表示沿着指定方向的向量, t 表示方向向量 t 的模长 \textbf{方向导数} = \lim_{t \to 0} \frac{f(\textbf{x}+\textbf{t}) - f(\textbf{x} )}{t} \\ 其中\textbf{t}表示沿着指定方向的向量,t表示方向向量\textbf{t}的模长 方向导数=t→0lim?tf(x+t)?f(x)?其中t表示沿着指定方向的向量,t表示方向向量t的模长
? 代入数据:
函数 y = x 1 2 + x 2 2 在点 ( 1 , 1 ) 处 : f ( x ) = 2 将方向向量单位化: ( ? 1 2 , ? 1 2 ) , 则 lim ? t → 0 t ? = ( ? 1 2 t , ? 1 2 t ) : 在点 ( 1 ? 1 2 t , 1 ? 1 2 t ) 处 : f ( x+t ) = t 2 ? 2 2 t + 2 lim ? t → 0 f ( x + t ) ? f ( x ) t = t 2 ? 2 2 t t = ? 2 2 函数y = x_{1}^{2} + x_{2}^{2} \quad 在点(1,1)处:f(\textbf{x})=2 \\ 将方向向量单位化:(-\frac{1}{\sqrt{2}}, -\frac{1}{\sqrt{2}}),则\lim_{t\to0}\vec{t} = (-\frac{1}{\sqrt{2}}t, -\frac{1}{\sqrt{2}}t) : \\ 在点(1-\frac{1}{\sqrt{2}}t,1-\frac{1}{\sqrt{2}}t)处:f(\textbf{x+t}) = t^{2} - 2\sqrt{2}t + 2 \\ \lim_{t \to 0} \frac{f(\textbf{x}+\textbf{t}) - f(\textbf{x} )}{t} = \frac{t^{2}-2\sqrt{2}t}{t} = -2\sqrt{2} 函数y=x12?+x22?在点(1,1)处:f(x)=2将方向向量单位化:(?2?1?,?2?1?),则t→0lim?t=(?2?1?t,?2?1?t):在点(1?2?1?t,1?2?1?t)处:f(x+t)=t2?22?t+2t→0lim?tf(x+t)?f(x)?=tt2?22?t?=?22? -
偏导数(偏导函数)
?
求上述函数对于 x 1 x_{1} x1?的偏导(函)数:
?
根据定义:
f ′ ( x 1 ) = ? y ? x 1 = 2 x 1 {f}'(x_{1}) = \frac{\partial y}{\partial x_{1}} = 2x_{1} f′(x1?)=?x1??y?=2x1? -
梯度
? 求原函数在点 ( 1 , 1 ) (1,1) (1,1) 处的梯度:
? 根据定义:
▽ f = ( f ′ ( x 1 ) , f ′ ( x 2 ) , ? ? , f ′ ( x n ) ) = ( ? y ? x 1 , ? y ? x 2 , ? ? , ? y ? x n ) \begin{align*} \bigtriangledown f &= (f'(x_{1}), f'(x_{2}), \cdots, f'(x_{n})) \\ & = (\frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial x_{n}} ) \end{align*} ▽f?=(f′(x1?),f′(x2?),?,f′(xn?))=(?x1??y?,?x2??y?,?,?xn??y?)?
? 代入数据得:
▽ f = ( f ′ ( x 1 ) , f ′ ( x 2 ) ) = ( ? y ? x 1 , ? y ? x 2 ) = ( 2 x 1 , 2 x 2 ) = ( 2 , 2 ) \begin{align*} \bigtriangledown f &= (f'(x_{1}), f'(x_{2})) \\ & = (\frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}) \\ & = (2x_{1}, 2x_{2}) \\ & = (2,2) \end{align*} ▽f?=(f′(x1?),f′(x2?))=(?x1??y?,?x2??y?)=(2x1?,2x2?)=(2,2)?
2.标量形式的链式法则
这个是本科工科数学一元分析学的重点,此处不用多做证明,只简单地记录一下:
若 y = f ( u ) , u = g ( x ) ,其中 y , u , x 均为标量,则 : ? y ? x = ? y ? u ? ? u ? x 若y=f(u), u=g(x),其中y,u,x均为标量,则:\\ \frac{\partial y}{\partial x} = \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial x} 若y=f(u),u=g(x),其中y,u,x均为标量,则:?x?y?=?u?y???x?u?
拓展到向量
设 y = f ( u ) , u = g ( x ? ) y=f(u), u=g(\vec{x}) y=f(u),u=g(x), 其中 x ? = ( x 1 , x 2 , ? ? , x n ) \vec{x} = (x_{1},x_{2},\cdots,x_{n}) x=(x1?,x2?,?,xn?) , y和u均为标量
则有:
? y ? x = ? y ? u ? ? u ? x ( 1 , n ) = 1 ? ( 1 , n ) \frac{\partial y}{\partial \textbf{x}} = \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial \textbf{x}} \\ (1,n) = 1\cdot(1,n) ?x?y?=?u?y???x?u?(1,n)=1?(1,n)
更具体的展开:
? y ? x = ( ? y ? x 1 , ? y ? x 2 , ? ? , ? y ? x n ) = ( ? y ? u ? ? u ? x 1 , ? y ? u ? ? u ? x 2 , ? ? , ? y ? u ? ? u ? x n ) = ( ? y ? u ) ? ( ? u ? x 1 , ? u ? x 2 , ? ? , ? u ? x n ) \begin{align*} \frac{\partial y}{\partial \textbf{x}} &= ( \frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial x_{n}}) \\ &= ( \frac{\partial y}{\partial u}\cdot\frac{\partial u}{\partial x_{1}}, \frac{\partial y}{\partial u}\cdot\frac{\partial u}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial u}\cdot\frac{\partial u}{\partial x_{n}}) \\ &= (\frac{\partial y}{\partial u}) \cdot ( \frac{\partial u}{\partial x_{1}}, \frac{\partial u}{\partial x_{2}}, \cdots, \frac{\partial u}{\partial x_{n}}) \end{align*} ?x?y??=(?x1??y?,?x2??y?,?,?xn??y?)=(?u?y???x1??u?,?u?y???x2??u?,?,?u?y???xn??u?)=(?u?y?)?(?x1??u?,?x2??u?,?,?xn??u?)?
这是显然的。
3.矩阵形式的链式法则
前面提到的例子全部都只涉及到对标量多元函数求偏导数,这是本科的工科数学中就很熟悉的内容。下面介绍的矩阵形式的链式法则均涉及到向量多元函数的偏导数。对向量多元函数求梯度得到的是一个矩阵。
仅中间变量是向量
设 y = f ( u ? ) , u ? = g ( x ? ) y=f(\vec{u}), \vec{u}=g(\vec{x}) y=f(u),u=g(x), 其中 u ? = ( u 1 , u 2 , ? ? , u k ) \vec{u} = (u_{1},u_{2},\cdots,u_{k}) u=(u1?,u2?,?,uk?) , x ? = ( x 1 , x 2 , ? ? , x n ) \vec{x} = (x_{1},x_{2},\cdots,x_{n}) x=(x1?,x2?,?,xn?) , y为标量
链式法则的具体展开:
? y ? x = ( ? y ? x 1 , ? y ? x 2 , ? ? , ? y ? x n ) = ( ? y ? u ? ? u ? x 1 , ? y ? u ? ? u ? x 2 , ? ? , ? y ? u ? ? u ? x n ) = ( ? y ? u ) ? ( ? u ? x 1 , ? u ? x 2 , ? ? , ? u ? x n ) = ( ? y ? u 1 , ? y ? u 2 , ? ? , ? y ? u k ) ? [ ? u 1 ? x ? u 2 ? x ? ? u k ? x ] = ( ? y ? u 1 , ? y ? u 2 , ? ? , ? y ? u k ) ? [ ? u 1 ? x 1 ? u 1 ? x 2 ? ? u 1 ? x n ? u 2 ? x 1 ? u 2 ? x 2 ? ? u 2 ? x n ? ? ? ? ? u k ? x 1 ? u k ? x 2 ? ? u k ? x n ] \begin{align*} \frac{\partial y}{\partial \textbf{x}} &= ( \frac{\partial y}{\partial x_{1}}, \frac{\partial y}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial x_{n}}) \\ &= ( \frac{\partial y}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial y}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial y}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= (\frac{\partial y}{\partial \textbf{u}}) \cdot ( \frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= (\frac{\partial {y}}{\partial u_{1}}, \frac{\partial {y}}{\partial u_{2}}, \cdots, \frac{\partial {y}}{\partial u_{k}}) \cdot \begin{bmatrix} \frac{\partial u_{1}}{\partial \textbf{x}} \\ \frac{\partial u_{2}}{\partial \textbf{x}}\\ \cdots \\ \frac{\partial u_{k}}{\partial \textbf{x}} \end{bmatrix} \\ &= (\frac{\partial {y}}{\partial u_{1}}, \frac{\partial {y}}{\partial u_{2}}, \cdots, \frac{\partial {y}}{\partial u_{k}}) \cdot \begin{bmatrix} \frac{\partial u_{1}}{\partial x_{1}}& \frac{\partial u_{1}}{\partial x_{2}}& \cdots & \frac{\partial u_{1}}{\partial x_{n}}\\ \frac{\partial u_{2}}{\partial x_{1}}& \frac{\partial u_{2}}{\partial x_{2}}& \cdots & \frac{\partial u_{2}}{\partial x_{n}} \\ \cdots &\cdots & \cdots & \cdots \\ \frac{\partial u_{k}}{\partial x_{1}}& \frac{\partial u_{k}}{\partial x_{2}}& \cdots & \frac{\partial u_{k}}{\partial x_{n}} \end{bmatrix} \end{align*} ?x?y??=(?x1??y?,?x2??y?,?,?xn??y?)=(?u?y???x1??u?,?u?y???x2??u?,?,?u?y???xn??u?)=(?u?y?)?(?x1??u?,?x2??u?,?,?xn??u?)=(?u1??y?,?u2??y?,?,?uk??y?)? ??x?u1???x?u2????x?uk??? ?=(?u1??y?,?u2??y?,?,?uk??y?)? ??x1??u1???x1??u2????x1??uk????x2??u1???x2??u2????x2??uk?????????xn??u1???xn??u2????xn??uk??? ??
即:
? y ? x = ? y ? u ? ? u ? x ( 1 , n ) = ( 1 , k ) ? ( k , n ) \frac{\partial y}{\partial \textbf{x}} = \frac{\partial y}{\partial \textbf{u}} \cdot \frac{\partial \textbf{u}}{\partial \textbf{x}} \\ (1,n) = (1,k)\cdot(k,n) ?x?y?=?u?y???x?u?(1,n)=(1,k)?(k,n)
所有变量均为向量
设 y ? = f ( u ? ) , u ? = g ( x ? ) \vec{y}=f(\vec{u}), \vec{u}=g(\vec{x}) y?=f(u),u=g(x), 其中 y ? = ( y 1 , y 2 , ? ? , y m ) \vec{y} = (y_{1},y_{2},\cdots,y_{m}) y?=(y1?,y2?,?,ym?) , u ? = ( u 1 , u 2 , ? ? , u k ) \vec{u} = (u_{1},u_{2},\cdots,u_{k}) u=(u1?,u2?,?,uk?) , x ? = ( x 1 , x 2 , ? ? , x n ) \vec{x} = (x_{1},x_{2},\cdots,x_{n}) x=(x1?,x2?,?,xn?)
链式法则的具体展开:
? y ? x = ( ? y ? x 1 , ? y ? x 2 , ? ? , ? y ? x n ) = ( ? y ? u ? ? u ? x 1 , ? y ? u ? ? u ? x 2 , ? ? , ? y ? u ? ? u ? x n ) = ( ? y ? u ) ? ( ? u ? x 1 , ? u ? x 2 , ? ? , ? u ? x n ) = [ ? y 1 ? u 1 ? y 1 ? u 2 ? ? y 1 ? u k ? y 2 ? u 1 ? y 2 ? u 2 ? ? y 2 ? u k ? ? ? ? ? y m ? u 1 ? y m ? u 2 ? ? y m ? u k ] ? [ ? u 1 ? x 1 ? u 1 ? x 2 ? ? u 1 ? x n ? u 2 ? x 1 ? u 2 ? x 2 ? ? u 2 ? x n ? ? ? ? ? u k ? x 1 ? u k ? x 2 ? ? u k ? x n ] \begin{align*} \frac{\partial \textbf{y}}{\partial \textbf{x}} &= ( \frac{\partial \textbf{y}}{\partial x_{1}}, \frac{\partial \textbf{y}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{y}}{\partial x_{n}}) \\ &= ( \frac{\partial \textbf{y}}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial \textbf{y}}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{y}}{\partial \textbf{u}}\cdot\frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= (\frac{\partial \textbf{y}}{\partial \textbf{u}}) \cdot ( \frac{\partial \textbf{u}}{\partial x_{1}}, \frac{\partial \textbf{u}}{\partial x_{2}}, \cdots, \frac{\partial \textbf{u}}{\partial x_{n}}) \\ &= \begin{bmatrix} \frac{\partial y_{1}}{\partial u_{1}}& \frac{\partial y_{1}}{\partial u_{2}}& \cdots & \frac{\partial y_{1}}{\partial u_{k}}\\ \frac{\partial y_{2}}{\partial u_{1}}& \frac{\partial y_{2}}{\partial u_{2}}& \cdots & \frac{\partial y_{2}}{\partial u_{k}} \\ \cdots &\cdots & \cdots & \cdots \\ \frac{\partial y_{m}}{\partial u_{1}}& \frac{\partial y_{m}}{\partial u_{2}}& \cdots & \frac{\partial y_{m}}{\partial u_{k}} \end{bmatrix} \cdot \begin{bmatrix} \frac{\partial u_{1}}{\partial x_{1}}& \frac{\partial u_{1}}{\partial x_{2}}& \cdots & \frac{\partial u_{1}}{\partial x_{n}}\\ \frac{\partial u_{2}}{\partial x_{1}}& \frac{\partial u_{2}}{\partial x_{2}}& \cdots & \frac{\partial u_{2}}{\partial x_{n}} \\ \cdots &\cdots & \cdots & \cdots \\ \frac{\partial u_{k}}{\partial x_{1}}& \frac{\partial u_{k}}{\partial x_{2}}& \cdots & \frac{\partial u_{k}}{\partial x_{n}} \end{bmatrix} \end{align*} ?x?y??=(?x1??y?,?x2??y?,?,?xn??y?)=(?u?y???x1??u?,?u?y???x2??u?,?,?u?y???xn??u?)=(?u?y?)?(?x1??u?,?x2??u?,?,?xn??u?)= ??u1??y1???u1??y2????u1??ym????u2??y1???u2??y2????u2??ym?????????uk??y1???uk??y2????uk??ym??? ?? ??x1??u1???x1??u2????x1??uk????x2??u1???x2??u2????x2??uk?????????xn??u1???xn??u2????xn??uk??? ??
即:
? y ? x = ? y ? u ? ? u ? x ( m , n ) = ( m , k ) ? ( k , n ) \frac{\partial \textbf{y}}{\partial \textbf{x}} = \frac{\partial \textbf{y}}{\partial \textbf{u}} \cdot \frac{\partial \textbf{u}}{\partial \textbf{x}} \\ (m,n) = (m,k)\cdot(k,n) ?x?y?=?u?y???x?u?(m,n)=(m,k)?(k,n)
如果将此时的链式法则画出计算图,可以清晰地看出,向量函数 y \textbf{y} y 对 向量 x \textbf{x} x 求偏导,实际上就是遍历了从y到x的所有依赖关系。这就是上面这个矩阵相乘的本质。 ]
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!