量子コンピュータの基本 - 線形代数の公式~スカラー変数のベクトル微分、行列微分
§ この記事の目的
量子コンピュータのみに限らず、機械学習や深層学習の理論でもよく出てくる線形代数の計算や式展開のうち、ベクトルや行列を用いた微分の計算の公式とその導出方法について確認します。
§ 微分公式のまとめ
まずは公式の一覧を示します。
1. ベクトル微分の公式
ここでは、xとaを列ベクトル、Aを行列とします。
∂x∂xTa=∂x∂aTx=a(式1)∂x∂xTAx=(A+AT)x(式2)∂x∂tr(xaT)=∂x∂tr(axT)=a(式3)∂x∂(a−Ax)T(a−Ax)=−2AT(a−Ax)(式4)
2. 行列微分の公式
ここでは、xとyを列ベクトル、XとAを行列とします。
また、∣A∣は行列Aの行列式を意味します。
∂X∂xTXy=xyT(式5)∂X∂xTX−1y=−X−1xyTX−1(式6)∂X∂log∣X∣=(X−1)T(式7)∂X∂tr(X)=I(式8)∂X∂tr(XA)=AT(式9)∂X∂tr(XTA)=A(式10)∂X∂tr(XAAT)=X(A+AT)(式10)∂x∂log∣X∣=tr(X−1∂x∂X)(式11)
スカラーyが行列Xの関数y=f(X)で表される場合、yの関数g(y)の行列Xでの微分は、
∂X∂g(y)=∂y∂g(y)∂X∂f(X)(式12)
§ 公式の導出
(式1)~(式4)までの導出方法を確認します。
厳密な証明でなく、あくまで一例ですのでご了承下さい。
(式5)以降は紙面の都合上、導出を割愛します。
1. (式1)の導出
例として3次元を取り使いますが、どの次元でも結果は同じとなります。
列ベクトルx,aをそれぞれ以下とします。
x=x1x2x3,a=a1a2a3
また、ベクトル微分∂x∂は以下のように作用するものとします。
∂x∂=∂x1∂∂x2∂∂x3∂
以下、導出例です。
∂x∂xTa=∂x∂⎩⎨⎧(x1x2x3)a1a2a3⎭⎬⎫=∂x∂(a1x1+a2x2+0a3x3)=∂x1∂∂x2∂∂x3∂(a1x1+a2x2+a3x3)=∂x1∂(a1x1+a2x2+a3x3)∂x2∂(a1x1+a2x2+a3x3)∂x3∂(a1x1+a2x2+a3x3)=a1a2a3=a
また、
∂x∂aTx=∂x∂⎩⎨⎧(a1a2a3)x1x2x3⎭⎬⎫=∂x∂(a1x1+a2x2+0a3x3)=∂x1∂∂x2∂∂x3∂(a1x1+a2x2+a3x3)=a1a2a3=a
2. (式2)の導出
行列Aを以下とします。
A=A11A21A31A12A22A32A13A23A33
以下、導出例です。
∂x∂xTAx=∂x∂⎩⎨⎧(x1x2x3)A11A21A31A12A22A32A13A23A33x1x2x3⎭⎬⎫=∂x∂⎩⎨⎧(A11x1+A21x2+A31x3A12x1+A22x2+A32x3A13x1+A23x2+A33x3)x1x2x3⎭⎬⎫=∂x∂(A11x1x1+A21x1x2+A31x1x3+A12x1x2+A22x2x2+A32x2x3+A13x1x3+A23x2x3+A33x3x3)=2A11x1+A21x2+A31x3+A12x2+A13x3A21x1+A12x1+2A22x2+A32x3+A23x3A31x1+A32x2+A13x1+A23x2+2A33x3=A11x1+A21x2+A31x3A12x1+A22x2+A32x3A13x1+A23x2+A33x3+A11x1+A12x2+A13x3A21x1+A22x2+A23x3A31x1+A32x2+A33x3=A11A12A13A21A22A23A31A32A33x1x2x3+A11A21A31A12A22A32A13A23A33x1x2x3=(AT+A)x=(A+AT)x
3. (式3)の導出
∂x∂tr(xaT)
ここで、
xaT=x1x2x3(a1a2a3)=a1x1a1x2a1x3a2x1a2x2a2x3a3x1a3x2a3x3
トレースを取ると、
tr(xaT)=a1x1+a2x2+a3x3
よって、
∂x∂tr(xaT)=∂x∂(a1x1+a2x2+a3x3)=∂x1∂(a1x1+a2x2+a3x3)∂x2∂(a1x1+a2x2+a3x3)∂x3∂(a1x1+a2x2+a3x3)=a1a2a3=a
また、
∂x∂tr(axT)
ここで、
axT=a1a2a3(x1x2x3)=a1x1a2x1a3x1a1x2a2x2a3x2a1x3a2x3a3x3
トレースを取ると、
tr(axT)=a1x1+a2x2+a3x3
これ以降は同じ導出のため省略します。
4. (式4)の導出
∂x∂(a−Ax)T(a−Ax)=∂x∂(aT−xTAT)(a−Ax)=∂x∂(aTa−aTAx−xTATa+xTATAx)
ここで第一項aTaはxに関わらないことから微分すると消える項のため除外します。
上式=∂x∂(−aTAx−xTATa+xTATAx)=∂x∂⎩⎨⎧−(a1a2a3)A11A21A31A12A22A32A13A23A33x1x2x3−(x1x2x3)A11A12A13A21A22A23A31A32A33a1a2a3+(x1x2x3)A11A12A13A21A22A23A31A32A33A11A21A31A12A22A32A13A23A33x1x2x3⎭⎬⎫=∂x∂⎩⎨⎧−(a1A11+a2A21+a3A31a1A12+a2A22+a3A32a1A13+a2A23+a3A33)x1x2x3−(a1A11+a2A12+a3A13a1A21+a2A22+a3A23a1A31+a2A32+a3A33)a1a2a3+(x1x2x3)B11B21B31B12B22B32B13B23B33x1x2x3⎭⎬⎫
ここで、行列Bを以下のように仮置きしました。
B=B11B21B31B12B22B32B13B23B33=A11A11+A21A21+A31A31A12A11+A22A21+A32A31A13A11+A23A21+A33A31A11A12+A21A22+A31A32A12A12+A22A22+A32A32A13A12+A23A22+A33A32A11A13+A21A23+A31A33A12A13+A22A23+A32A33A13A13+A23A23+A33A33
これより、
上式=∂x∂⎩⎨⎧−(x1(a1A11+a2A21+a3A31)+x2(a1A12+a2A22+a3A32)+x3(a1A13+a2A23+a3A33))−(a1(x1A11+x2A12+x3A13)+a2(x1A21+x2A22+x3A23)+a3(x1A31+x2A32+x3A33))+(x1B11+x2B21+x3B31x1B12+x2B22+x3B32x1B13+x2B23+x3B33)x1x2x3⎭⎬⎫=∂x∂{−x1(a1A11+a2A21+a3A31)−x2(a1A12+a2A22+a3A32)−x3(a1A13+a2A23+a3A33)−x1(a1A11+a2A21+a3A31)−x2(a1A12+a2A21+a3A32)−x3(a1A13+a2A23+a3A33)+x12B11+x1x2B21+x1x3B31+x1x2B12+x22B22+x2x3B32+x1x3B13+x2x3B23+x32B33}
ベクトル微分∂x∂を作用させると、
上式=−(a1A11+a2A21+a3A31)−(a1A11+a2A21+a3A31)+2x1B11+x2B21+x3B31+x2B12+x3B13−(a1A12+a2A22+a3A32)−(a1A12+a2A22+a3A32)+x1B21+x1B12+2x2B22+x3B32+x3B33−(a1A13+a2A23+a3A33)−(a1A13+a2A23+a3A33)+x1B31+x2B32+x1B13+x2B23+2x3B33
BをAで戻すと、
上式=−2(a1A11+a2A21+a3A31)+2x1(A11A11+A21A21+A31A31)+x2(A12A11+A22A21+A32A31)+x3(A13A11+A23A21+A33A31)+x2(A11A12+A21A22+A31A32)+x3(A11A13+A21A23+A31A33)−2(a1A12+a2A22+a3A32)+x1(A12A11+A22A21+A32A31)+x1(A11A12+A21A22+A31A32)+2x2(A12A12+A22A22+A32A32)+x3(A13A12+A23A22+A33A32)+x3(A12A13+A22A23+A32A33)−2(a1A13+a2A23+a3A33)+x1(A13A11+A23A21+A33A31)+x2(A13A12+A23A22+A33A32)+x1(A11A13+A21A23+A31A33)+x2(A12A13+A22A23+A32A33)+2x3(A13A13+A23A23+A33A33)=−2(a1A11+a2A21+a3A31)+2x1(A11A11+A21A21+A31A31)+2x2(A11A12+A21A22+A31A32)+2x3(A11A13+A21A23+A31A33)−2(a1A12+a2A22+a3A32)+2x1(A11A12+A21A22+A31A32)+2x2(A12A12+A22A22+A32A32)+2x3(A12A13+A22A23+A32A33)−2(a1A13+a2A23+a3A33)+2x1(A11A13+A21A23+A31A33)+2x2(A12A13+A22A23+A32A33)+2x3(A13A13+A23A23+A33A33)=−2a1A11+a2A21+a3A31a1A12+a2A22+a3A32a1A13+a2A23+a3A33+2A11A11+A21A21+A31A31A11A12+A21A22+A31A32A11A13+A21A23+A31A33A11A12+A21A22+A31A32A12A12+A22A22+A32A32A12A13+A22A23+A32A33A11A13+A21A23+A31A33A12A13+A22A23+A32A33A13A13+A23A23+A33A33x1x2x3=−2⎩⎨⎧A11A12A13A21A22A23A31A32A33a1a2a3−A11A12A13A21A22A23A31A32A33A11A21A31A12A22A32A13A23A33x1x2x3⎭⎬⎫=−2(ATa−ATAx)=−2AT(a−Ax)
5. (式5)以降の導出について
(式5)以降は省略しますが、以下のように行列微分を用いれば導出ができます。
∂X∂=∂X11∂∂X21∂∂X31∂∂X12∂∂X22∂∂X32∂∂X13∂∂X23∂∂X33∂