import torch
from torch import nn
from torch.quantization import quantize_dynamic
class DemoModel(nn.Module):
    """Toy network used to demonstrate dynamic quantization.

    The forward pass is ``Conv2d -> ReLU -> Linear``. ``nn.Linear`` operates
    on the *last* dimension of its input. Here that is the spatial width
    ``W`` produced by the convolution. The 1x1 convolution (default
    stride/padding) preserves ``H`` and ``W``, so the input width must equal
    ``in_features``.

    Args:
        in_channels: Channels expected by the convolution (default 1).
        out_channels: Channels produced by the convolution (default 1).
        in_features: Size of the last dimension fed to the linear layer
            (default 2).
        out_features: Size of the last dimension produced by the linear
            layer (default 2).
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        in_features: int = 2,
        out_features: int = 2,
    ) -> None:
        super().__init__()
        # kernel_size=1 keeps H and W unchanged, so the Linear layer below
        # sees the same width as the original input.
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        self.relu = nn.ReLU()
        # Only this layer is targeted by quantize_dynamic({nn.Linear}) below.
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, ReLU, and linear layer in sequence.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)`` or
                ``(in_channels, H, W)``, with ``W == in_features``.

        Returns:
            Tensor with the same leading dimensions, ``out_channels``
            channels, and last dimension ``out_features``.
        """
        x = self.conv(x)
        x = self.relu(x)
        x = self.fc(x)
        return x
# Float32 reference model.
model_fp32 = DemoModel()

# Dynamically quantize only the nn.Linear submodules to int8; the Conv2d
# layer stays in float32. With the default inplace=False, quantize_dynamic
# returns a quantized copy and leaves model_fp32 itself unmodified.
model_int8 = quantize_dynamic(model_fp32, {nn.Linear}, dtype=torch.qint8)