以前は技術を概観しました。今回は実装を見てみたいと思います。
[論文解説]3Dガウシアンスプラッティング
https://blueqat.com/yuichiro_minato2/627fb949-37d7-463c-a2ff-57a3bb0e95f5
あと、なんかものすごいわかりやすい解説を見つけました。投稿されたばかりのようです。
3D Gaussian SplattingはNeRFをこえるかトレンドになるか?複数視点の画像から3D空間を再現する最新手法論文解説!
https://qiita.com/RyeWiskey/items/9ccc862db91e38e8bbc9
さて、こちらにPython実装がありました。
notebook形式で
https://github.com/thomasantony/splat/
上記のnotebook形式のもとのpythonコード
https://github.com/limacv/GaussianSplattingViewer/tree/main
コード自体はファイルが三つで両方とも同じです。
main.py
util.py
util_gau.py
のファイルとなっていました。この実装では、基本的には学習済みの3Dガウス関数のパラメータを読み込み、それを元にレンダリングを行うという実装になっています。
util.pyファイルには「カメラのセッティング」が入っています。
from OpenGL.GL import *
import OpenGL.GL.shaders as shaders
import numpy as np
import glm
import ctypes
また、util_gau.pyファイルにはガウス関数の読み込みや設定値が書いてあります。
import numpy as np
from plyfile import PlyData
from dataclasses import dataclass
例えば、
xyz: np.ndarray
rot: np.ndarray
scale: np.ndarray
opacity: np.ndarray
sh: np.ndarray
のようになっていて、xyzの座標、回転方向?、スケール、不透明度、SH係数(spherical harmonicsで球面調和関数のようです)となっていました。
あとは、初期値でしょうか?
def naive_gaussian():
gau_xyz = np.array([
0, 0, 0,
1, 0, 0,
0, 1, 0,
0, 0, 1,
]).astype(np.float32).reshape(-1, 3)
gau_rot = np.array([
1, 0, 0, 0,
1, 0, 0, 0,
1, 0, 0, 0,
1, 0, 0, 0
]).astype(np.float32).reshape(-1, 4)
gau_s = np.array([
0.03, 0.03, 0.03,
0.2, 0.03, 0.03,
0.03, 0.2, 0.03,
0.03, 0.03, 0.2
]).astype(np.float32).reshape(-1, 3)
gau_c = np.array([
1, 0, 1,
1, 0, 0,
0, 1, 0,
0, 0, 1,
]).astype(np.float32).reshape(-1, 3)
gau_c = (gau_c - 0.5) / 0.28209
gau_a = np.array([
1, 1, 1, 1
]).astype(np.float32).reshape(-1, 1)
return GaussianData(
gau_xyz,
gau_rot,
gau_s,
gau_a,
gau_c
)
色の情報がどこに入っているかはわかりませんでしたが、どうやら後の方のファイルでSHから色を計算しているようでした。
あとは、シンプルで今回の3Dガウシアンスプラッティング向けのplyファイルを読み込む関数がありました。
def load_ply(path):
補助のファイルはこの二つだけでかなりシンプルです。
次にメインのファイルを見てみます。今回の例題ファイルで読み込むツールは、
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import util
from util import Camera
from util_gau import load_ply, naive_gaussian, GaussianData
極めてシンプルですね。わかりやすくて好感が持てました。ちなみにrequirements.txtの中は、
glfw
PyGLM
imgui
PyOpenGL
numpy
imageio
plyfile
tqdm
scipy
matplotlib
となっています。シンプルで好きです。自分の環境では、google colabでは動かなかったので、localでjupyternotebookを立ち上げてそちらで実行したらうまくいきました。
メインのGaussianのクラスの中には、
これは三次元ガウス関数の共分散行列を返すための関数でしょうか
def compute_cov3d(self):
cov3D = np.diag(self.scale**2)
cov3D = self.rot.as_matrix().T @ cov3D @ self.rot.as_matrix()
return cov3D
こちらは三次元の共分散行列とカメラからカメラに向かって二次元の共分散行列を取り出すための関数
def get_cov2d(self, camera):
view_mat = camera.get_view_matrix()
g_pos_w = np.append(self.pos, 1.0)
(略)
こちらは深度?
def get_depth(self, camera):
view_matrix = camera.get_view_matrix()
position4 = np.append(self.pos, 1.0)
g_pos_view = view_matrix @ position4
depth = g_pos_view[2]
return depth
なんか円錐を作っている模様?
def get_conic_and_bb(self, camera):
cov2d = self.get_cov2d(camera)
det = np.linalg.det(cov2d)
if det == 0.0:
return None
det_inv = 1.0 / det
(略)
こちらは色です。どうやらやはり球面調和関数と方向から色を取るようです。
def get_color(self, dir) -> np.ndarray:
"""Samples spherical harmonics to get color for given view direction"""
c0 = self.sh[0:3] # f_dc_* from the ply file)
color = SH_C0 * c0
(略)
ここまでが基本的なGaussianのクラスの中身でした。次は普通の関数が続きます。
作った円錐を配置するみたいです。引数にはgaussianのオブジェクトとカメラが格納されていますので、レンダリングの前準備のようです。
def plot_conics_and_bbs(gaussian_objects, camera):
その次の関数は、どうやら最終的に画像を作る直前の状態を作る関数みたいですね。入れる引数はガウシアンのオブジェクト、カメラ情報、出力画像のサイズなどです。
def plot_opacity(gaussian: Gaussian, camera: Camera, w: int, h: int, bitmap: np.ndarray, alphas: np.ndarray):
最後に、これが最終的に絵を出すための関数のようです。中にplot_opacityが連続で呼び出されていて、plot_opacityの中でplot_conics_and_bbsが連続で呼び出されていました。
def plot_model(camera, gaussian_objects):
print('Sorting the gaussians by depth')
indices = np.argsort([gau.get_depth(camera) for gau in gaussian_objects])
print('Plotting with', len(gaussian_objects), 'gaussians')
bitmap = np.zeros((h, w, 3), np.float32)
alphas = np.zeros((h, w), np.float32)
for idx in tqdm(indices):
plot_opacity(gaussian_objects[idx], camera, w, h, bitmap, alphas)
return bitmap
ここまでがファイルの関数などの一式となっていて、これらを使って画像を作ります。
まず、モデルファイルですが、自分でも作れると思いますが、学習ができないので、一般に出回っているファイルを使いました。いろんなところに落ちています。3D Gaussian Splattingのページからこのアルゴリズム向けのplyファイルをダウンロードして使います。
このページのUsageからリンクがありますが、13GBくらいあります。
https://github.com/limacv/GaussianSplattingViewer
正直かなりファイルのサイズが大きいので気をつけてください。ファイルをダウンロードしたら、point_cloud.ply(もしくは名前変更したファイル)を読み込みます。load_plyはutil_gau.pyに入ってます。
model = load_ply('point_cloud_bike.ply')
from tqdm import tqdm
print('Loading gaussians ...')
gaussian_objects = []
for (pos, scale, rot, opacity, sh) in tqdm(zip(model.xyz, model.scale, model.rot, model.opacity, model.sh)):
gaussian_objects.append(Gaussian(pos, scale, rot, opacity, sh))
上記、計算のためにgaussian_objectsのリストにひたすら読み込んだファイルからガウシアンの情報を格納していますね。
最後に画像サイズを指定し、カメラ情報とターゲットとなる座標を決めれば画像が出ます。
plot_modelからbitmapを出力し、表示します。
(h, w) = (720, 1280)
camera = Camera(h, w, position=(-0.57651054, 2.99040512, -0.03924271), target=(-0.0, 0.0, 0.0))
bitmap = plot_model(camera, gaussian_objects)
plt.figure(figsize=(12, 12))
plt.imshow(bitmap, vmin=0, vmax=1.0)
plt.show()
これで3Dガウシアンススプラッティングの実装ができました。ガウシアンの配置や学習は今回のレンダリングのフェーズとは完全に独立して作れますので、それは興味があればそのうちやろうと思います。
今回の計算は見た感じでは完全にCPUで行っているのでかなり時間がかかります。
ガウシアンの読み込みでは、
6131954it [03:49, 26738.95it/s]
600万近くのポイントを読み込むのに、3分49秒かかりました。
あとはガウシアンのソートに多少時間がかかり、その後、プロットに時間がかかっています。
Sorting the gaussians by depth
Plotting with 6131954 gaussians
78%|█████████████████████▋ | 4757265/6131954 [2:19:32<1:25:48, 267.02it/s]
自分のマシンがかなり非力な2016のmacbook12なので、最新マシンならもっと早いと思います。
(まだ画像が出てないので出たら掲載します。)
なんか絵が上手く出なかったので、28万ガウシアンのモデルで勘弁してください。
Loading gaussians ...
281498it [00:19, 14727.80it/s]
Sorting the gaussians by depth
Plotting with 281498 gaussians
100%|██████████████████████████████████| 281498/281498 [23:09<00:00, 202.53it/s]
モデルも大小があるので、小さいモデルなら比較的早く出力されます。以上です。