common.title

Docs
Quantum Circuit
TYTAN CLOUD

QUANTUM GAMING


Desktop RAG

Overview
Terms of service

Privacy policy

Contact
Research

Sign in
Sign up
common.title

新しい言語モデルGrok-1を動かしてみた。NVIDIA H100ベンチマーク。

Yuichiro Minato

2024/03/21 04:00

イーロンマスクが提供していると話題ですね。

https://grok.x.ai/

3140億パラメータあるようです。

インストールなどはすべてgithubに書いてある通りにしました。

Llama2の70Bよりも多いので、多少インストールに苦労するかもしれません。

https://github.com/xai-org/grok-1

リポジトリをクローンしてフォルダに移動し、huggingfaceからテンソルをダウンロードします。

git clone https://github.com/xai-org/grok-1.git && cd grok-1
pip install huggingface_hub[hf_transfer]
huggingface-cli download xai-org/grok-1 --repo-type model --include ckpt-0/* --local-dir checkpoints --local-dir-use-symlinks False

このテンソルのファイルが結構大きくて、何回かトライしましたが、通信が遅いとダウンロードだけで45分くらいかかりました。最終的には通信が早い環境で計算できたので、10分くらいでダウンロードが終わりました。おそらくこれだけで300GBくらい消費するので、ディスクサイズは大きめに取っておく必要があります。僕は1TB確保しました。500GBくらいでもいいかもしれませんが。

ダウンロードが終わりましたら、Pythonの仮想環境を作り、必要なツールをインストールしました。

どうやらJAXを利用しているようです。

python3 -m venv virt
source virt/bin/activate

必要なパッケージをインストールします。

pip install -r requirements.txt

そして実行です。

python run.py

最初は4GPUで実行すると、すぐにデバイス数でエラーが起きました。

https://github.com/xai-org/grok-1/issues/38

同じような問題を抱えている人は多いみたいで、どうやらJAXを利用していますが、指定のインストールでGPUを使っている場合にはcuda対応のjaxのライブラリをインストール必要があるみたいです。そうでない場合には、deviceの数を(1,1)に指定するとCPUで動作するようです。

せっかくGPU入れてるので、GPUで動かしたいので、

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

で、cuda12対応のjaxlibを入れました。

最初deviceを(1,1)に設定してしまいました。どうやらCPUで動くらしく、ものすごい時間がかかります。あまりに時間がかかるので、GPUにしました。。。

deviceをcpuで動かすには、VRAMは不要ですが、十分なRAM容量が必要です。300GBのテンソルファイルなので、それ以上は必要そうです。

GPUの場合には同様にVRAMは300GB以上必要かもしれません。

ということで、快適にGrokを動かすには相当なスペックが必要そうです。RTX6000adaなどはVRAM48GBです。8枚あれば、384GBで要件は満たしそうです。また、A100 80GやH100などはVRAM80Gなので、8枚使えば合計で640GB確保できるので確実に動きそうです。

GPUの枚数はどうやら標準で8枚使用を想定しているようです。ただ、deviceの設定で枚数は指定できるようですが、その際にはkv headsの値を調整するなどが必要のようです。詳しくはdiscussionボードなどを参考にするのが良さそうです。

jaxlibがインストールされましたら、GPUが利用できそうですので、run.pyのファイルを実行すればいいのですが、使い勝手を良くするために、jupyternotebook形式に変更をして実行しました。

結局色々試したところ、かなりのマシンリソースを必要とするみたいなので、考えるのをやめてH100を8枚で実行しました。

import logging

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model

CKPT_PATH = "./checkpoints/"

grok_1_model = LanguageModelConfig(
    vocab_size=128 * 1024,
    pad_token=0,
    eos_token=2,
    sequence_len=8192,
    embedding_init_scale=1.0,
    output_multiplier_scale=0.5773502691896257,
    embedding_multiplier_scale=78.38367176906169,
    model=TransformerConfig(
      emb_size=48 * 128,
      widening_factor=8,
      key_size=128,
      num_q_heads=48,
      num_kv_heads=8,
      num_layers=64,
      attn_output_multiplier=0.08838834764831845,
      shard_activations=True,
      # MoE.
      num_experts=8,
      num_selected_experts=2,
      # Activation sharding.
      data_axis="data",
      model_axis="model",
    ),
)

inference_runner = InferenceRunner(
    pad_sizes=(1024,),
    runner=ModelRunner(
      model=grok_1_model,
      bs_per_device=0.125,
      checkpoint_path=CKPT_PATH,
    ),
    name="local",
    load=CKPT_PATH,
    tokenizer_path="./tokenizer.model",
    local_mesh_config=(1, 8),
    between_hosts_config=(1, 1),
)

流石にH100です。爆速でした。まずは初期化のようなものが、

inference_runner.initialize()

多分これ、ものすごい時間がかかるのですが、今回はマシンの性能のおかげで、

1min 37s

1分半で終わりました。

そしてこちらは一瞬で終わりました。

gen = inference_runner.run()

65.6 µs

ここまできたらあとはプロンプトの処理です。デフォルトで入っていたものを使います。

inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))

Output for prompt: The answer to life the universe and everything is of course 42.

The answer to the question of how to get a job in the games industry is not so simple.

I have been asked this question many times over the years and I have always struggled to give a good answer.

I have been in the games industry for over 20 years and I have seen many people come and go. I have seen people with no experience get jobs and I have seen people with years of experience get passed over.

There is no one answer

1min 52s

初回は時間がかかりましたが、二回目以降は結構早かったです。

inp = "what is quantum computer?"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))

Output for prompt: what is quantum computer?

Quantum computers are machines that use the properties of quantum physics to store data and perform computations. This can be extremely advantageous for certain tasks where they could vastly outperform even our best supercomputers.

Classical computers, which include smartphones and laptops, encode information in binary “bits” that can either be 0s or 1s. In a quantum computer, the basic unit of memory is a quantum bit or qubit.

Qubits are made using physical systems, such as the spin of an electron

10.7s程度でした。

お決まりの。

inp = "日本の首都は?"
print(sample_from_model(gen, inp, max_len=100, temperature=0.01))

⇒ 東京 ⇒ 答えは「東京」です。 ※正解の答えが複数ある場合は、そのうちの一つを答えとして採用しています。 ※正解の答えが複数ある場合は、そのうちの一つを答えとして採用しています。 ※正解の答

レシピもトライ。

inp = "トマトソースパスタのレシピを教えて?"
print(sample_from_model(gen, inp, max_len=100, temperature=0.01))

材料はトマトとパスタとオリーブオイルと塩と砂糖と黒胡椒とバジルとにんにくとオニオンとチーズと赤ワインとオリーブと鶏肉とベーコンと牛肉と豚肉と魚介��

レシピはちょっとイマイチかな。日本語はそこまでって感じかもしれません。

次に量子コンピュータのコードを書かせてみます。

inp = "Write a quantum teleportation on qiskit."
print(sample_from_model(gen, inp, max_len=100, temperature=0.01))

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import edward2 as ed

tfd = tfp.distributions
tfb = tfp.bijectors

from math import pi, sqrt
from edward2.interceptor import get_interceptor

結果はイマイチでした。

全体的にはまだ学習は細かいところまでは行き届いていないので、とりあえず公開したという感じでしょうか。英語のところはもしかしたらいい感じで行けるのかもしれません。他の人のトライアルに期待をしたいと思います。以上です。

(追記)途中でH100*8で640GBのVRAMでもビデオメモリエラーが起きました。うまく使わないとビデオメモリはかなり消費するようです。

H100 4枚でやってみましたがメモリエラーが出ました。

© 2025, blueqat Inc. All rights reserved