import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class M(nn.Module):
    def __init__(self):
        super().__init__()
        # quantization op
        self.quant = torch.ao.quantization.QuantStub()
        # Hadamard gate
        self.H = torch.tensor([[1, 1], [1, -1]]) / np.sqrt(2)
        # dequant (probably won't be used, but kept just in case)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        # the quantum computation is meant to go through einsum...
        x = torch.einsum('a,ab->b', x, self.quant(self.H))
        return x
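For reference, in plain floating point this einsum simply applies the Hadamard gate to the state vector, so the |0⟩ state [1, 0] should map to [1/√2, 1/√2]. A quick check without any quantization (a sketch reusing the imports at the top of this post):

state = torch.tensor([1., 0.])
H = torch.tensor([[1, 1], [1, -1]]) / np.sqrt(2)
print(torch.einsum('a,ab->b', state, H))  # tensor([0.7071, 0.7071])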
Now run static quantization on this model.
# create a model instance
model_fp32 = M()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
# insert observers
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
# calibrate with a representative input (the |0> state)
input_fp32 = torch.tensor([1., 0.])
res1 = model_fp32_prepared(input_fp32)
# convert to the quantized model and run it
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
res2 = model_int8(input_fp32)
"currently we don’t have a quantized kernel for einsum, we would be happy to review a PR if someone is interested in implementing. In the meanwhile, a workaround could be to dequantize → floating point einsum → quantize."
class M2(nn.Module):
    def __init__(self):
        super().__init__()
        # quantization op
        self.quant = torch.ao.quantization.QuantStub()
        # Hadamard gate
        self.H = torch.tensor([[1, 1], [1, -1]]) / np.sqrt(2)

    def forward(self, x):
        x = self.quant(x)
        H = self.quant(self.H)
        # no quantized einsum kernel, so write the contraction out by hand
        # (only the first amplitude of H|x> is computed here)
        x = torch.tensor([H[0][0] * x[0] + H[0][1] * x[1]])
        return x
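For the |0⟩ input used below, this single amplitude works out to H[0][0]·1 + H[0][1]·0 = 1/√2 ≈ 0.7071 in floating point; the question is what the quantized model produces for it.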
# create a model instance
model_fp32 = M2()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
# insert observers
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
# calibrate with a representative input (the |0> state)
input_fp32 = torch.tensor([1., 0.])
res1 = model_fp32_prepared(input_fp32)
# convert to the quantized model and run it
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
res2 = model_int8(input_fp32)