I had heard that the JAX library allows for fast computation, so I decided to try it out on the matrix factorization (MF) I had previously implemented with PyTorch. For the most part, I follow my earlier PyTorch article.
Matrix Factorization with PyTorch
https://blueqat.com/yuichiro_minato2/9861dd84-010d-4258-8899-6b173406136a
I import the libraries and initialize everything the same way as in the reference article; only the conventions differ slightly.
import jax.numpy as jnp
from jax import grad, jit
from jax import random
import matplotlib.pyplot as plt
%matplotlib inline
#initial matrix
R = jnp.array([[1.,1,0,3],[2,5,0,5],[3,1,2,2],[0,1,3,0],[1,0,3,1]])
#get number of rows and columns
rows, cols = R.shape
#number of latent factors
r = 2
#initialize the two factor matrices with random numbers
seed = 1701
key = random.PRNGKey(seed)
P = random.uniform(key, shape=(r, rows))
Q = random.uniform(key, shape=(r, cols))
#learning rate
e = 0.01
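One small aside: P and Q above are both drawn from the same PRNG key. The usual JAX convention is to split the key so that each draw gets an independent random stream. A minimal sketch of that variant (not what was used for the results below, so the numbers would change) looks like this:

# hypothetical variant: split the key so P and Q use independent streams
key = random.PRNGKey(seed)
key_P, key_Q = random.split(key)
P = random.uniform(key_P, shape=(r, rows))
Q = random.uniform(key_Q, shape=(r, cols))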
Next, I define the loss function. Since jax.grad differentiates with respect to the first argument by default, and I want gradients for P and Q separately, I wrote two loss functions, one with each matrix as the first argument. The updates are plain gradient descent.
def loss_P(Pt, Rt, Qt):
    # squared error over observed (non-zero) entries only
    non_zero_mask = Rt != 0
    filtered_matrix = (Rt - Pt.T @ Qt) * non_zero_mask
    return jnp.sum(jnp.square(filtered_matrix))

def loss_Q(Qt, Rt, Pt):
    # same loss, but with Q first so grad differentiates with respect to Q
    non_zero_mask = Rt != 0
    filtered_matrix = (Rt - Pt.T @ Qt) * non_zero_mask
    return jnp.sum(jnp.square(filtered_matrix))
@jit
def update_P(Pt, Rt, Qt):
    # one gradient-descent step on P
    grads_P = grad(loss_P)(Pt, Rt, Qt)
    return Pt - e * grads_P

@jit
def update_Q(Qt, Rt, Pt):
    # one gradient-descent step on Q
    grads_Q = grad(loss_Q)(Qt, Rt, Pt)
    return Qt - e * grads_Q
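As a side note, the two separate loss functions are not strictly necessary: jax.grad accepts an argnums argument, so a single loss can yield gradients for both P and Q at once. A minimal sketch of that alternative (not what was used for the results below):

# hypothetical alternative: one loss, gradients w.r.t. both factor matrices
def loss(Pt, Qt, Rt):
    non_zero_mask = Rt != 0
    filtered_matrix = (Rt - Pt.T @ Qt) * non_zero_mask
    return jnp.sum(jnp.square(filtered_matrix))

@jit
def update(Pt, Qt, Rt):
    grads_P, grads_Q = grad(loss, argnums=(0, 1))(Pt, Qt, Rt)
    return Pt - e * grads_P, Qt - e * grads_Q

Inside the loop this would be called as P, Q = update(P, Q, R).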
arr = []
for epoch in range(100):
    P = update_P(P, R, Q)
    Q = update_Q(Q, R, P)
    arr.append(loss_P(P, R, Q))

plt.plot(arr)
plt.show()
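Since the original motivation was speed, note that JAX dispatches work asynchronously, so a fair timing needs block_until_ready(). A rough sketch for timing one jitted update step, using the arrays defined above:

import time

# warm-up: the first call triggers JIT compilation
update_P(P, R, Q).block_until_ready()

start = time.perf_counter()
update_P(P, R, Q).block_until_ready()  # wait for the async computation to finish
print(f"one update step: {time.perf_counter() - start:.6f} s")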
Finally, I ran the loop and plotted the loss. The model appears to be learning properly.
P.T@Q
Array([[0.45281273, 1.462819 , 3.1225102 , 2.8387036 ],
[2.618786 , 4.3629026 , 7.2835083 , 5.19678 ],
[1.4783977 , 1.8638482 , 2.5360398 , 1.2928767 ],
[0.505274 , 1.3719578 , 2.7996037 , 2.4546168 ],
[1.4598143 , 1.8708394 , 2.5841634 , 1.3599337 ]], dtype=float32)
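To check the fit only where ratings were actually observed, the reconstruction can be masked the same way as in the loss. A small sketch:

# compare the reconstruction with R on observed (non-zero) entries only
mask = R != 0
print(jnp.where(mask, P.T @ Q, 0))
print(loss_P(P, R, Q))  # final masked squared error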