前言 在vllm里面看到flash attention包了一层@torch.library.custom_op
装饰器(https://github.com/vllm-project/vllm/pull/7536),查阅了一下资料,发现这个是torch 2.4之后的新feature,防止打算torch compile的graph,翻译一下官方教程稍微了解一下这个用法。
来源:https://pytorch.org/tutorials/advanced/python_custom_ops.html
Python Custom Operators 教程 这个教程介绍了Python自定义运算符的主题。它列出了我们将从这一教程中学习到的内容,包括如何将用Python编写的自定义运算符与PyTorch集成,以及如何使用torch.library.opcheck来测试自定义运算符。所需的先决条件是安装了PyTorch 2.4或更高版本。
PyTorch提供了大量可以在Tensor上运行的运算符(例如torch.add、torch.sum等)。但是,您可能希望在PyTorch中使用一个新的自定义运算符,可能是由第三方库编写的。本教程展示了如何封装Python函数,使它们的行为类似于PyTorch原生运算符。创建PyTorch中的自定义运算符的原因可能包括:
将任意Python函数视为不透明的可调用对象,与torch.compile相对应(即防止torch.compile跟踪进入函数)。 请注意,如果您的操作可以表示为现有PyTorch运算符的组合,那么通常就不需要使用自定义运算符(例如,支持自动微分的运算应该可以直接工作)。
例子:将PIL库的crop功能封装为一个自定义运算符 假设我们在使用PIL的crop
操作
import torchfrom torchvision.transforms.functional import to_pil_image, pil_to_tensorimport PILimport IPythonimport matplotlib.pyplot as pltdef crop (pic, box) : img = to_pil_image(pic.cpu()) cropped_img = img.crop(box) return pil_to_tensor(cropped_img).to(pic.device) / 255. def display (img) : plt.imshow(img.numpy().transpose((1 , 2 , 0 )))
img = torch.ones(3 , 64 , 64 ) img *= torch.linspace(0 , 1 , steps=64 ) * torch.linspace(0 , 1 , steps=64 ).unsqueeze(-1 ) display(img)
cropped_img = crop(img, (10 , 10 , 50 , 50 )) display(cropped_img)
crop
功能无法被torch.compile
有效地开箱即用处理:torch.compile
在无法处理的函数上会引发"图中断"(https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks),而图中断会导致性能下降。以下代码通过引发错误来演示这一点(如果发生图中断,torch.compile(with fullgraph=True)
会引发错误)。
@torch.compile(fullgraph=True) def f (img) : return crop(img, (10 , 10 , 50 , 50 ))# The following raises an error. Uncomment the line to see it. # cropped_img = f(img)
为了能在torch.compile
中使用crop
作为黑盒操作,我们需要做两件事:
为该运算符添加"FakeTensor kernel"(又称"meta kernel")。给定输入Tensor的元数据(例如形状),此函数说明如何计算输出Tensor的元数据。 from typing import Sequence# Use torch.library.custom_op to define a new custom operator. # If your operator mutates any input Tensors, their names must be specified # in the ``mutates_args`` argument. @torch.library.custom_op("mylib::crop", mutates_args=()) def crop (pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor: img = to_pil_image(pic.cpu()) cropped_img = img.crop(box) return (pil_to_tensor(cropped_img) / 255. ).to(pic.device, pic.dtype)# Use register_fake to add a ``FakeTensor`` kernel for the operator @crop.register_fake def _ (pic, box) : channels = pic.shape[0 ] x0, y0, x1, y1 = box return pic.new_empty(channels, y1 - y0, x1 - x0)
做了上述操作之后,crop现在可以在不产生图中断的情况下正常工作了。
@torch.compile(fullgraph=True) def f (img) : return crop(img, (10 , 10 , 50 , 50 )) cropped_img = f(img) display(img)
display(cropped_img)
为crop添加训练支持
使用torch.library.register_autograd
为运算符添加训练支持。相比直接使用torch.autograd.Function
,优先使用这种方式;因为autograd.Function
与PyTorch运算符注册API组合使用时,可能会在与torch.compile
组合时导致无声的不正确性。
crop
的梯度公式本质上是PIL.paste
(我们把推导留作读者练习)。让我们首先将paste
封装为一个自定义运算符:
@torch.library.custom_op("mylib::paste", mutates_args=()) def paste (im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor: assert im1.device == im2.device assert im1.dtype == im2.dtype im1_pil = to_pil_image(im1.cpu()) im2_pil = to_pil_image(im2.cpu()) PIL.Image.Image.paste(im1_pil, im2_pil, coord) return (pil_to_tensor(im1_pil) / 255. ).to(im1.device, im1.dtype)@paste.register_fake def _ (im1, im2, coord) : assert
im1.device == im2.device assert im1.dtype == im2.dtype return torch.empty_like(im1)
现在让我们使用register_autograd
来为crop
指定梯度公式:
def backward (ctx, grad_output) : grad_input = grad_output.new_zeros(ctx.pic_shape) grad_input = paste(grad_input, grad_output, ctx.coords) return grad_input, None def setup_context (ctx, inputs, output) : pic, box = inputs ctx.coords = box[:2 ] ctx.pic_shape = pic.shape crop.register_autograd(backward, setup_context=setup_context)
注意,backward必须是由PyTorch可理解的运算符组成,这也是我们将paste封装为自定义运算符而不直接使用PIL的paste的原因。
img = img.requires_grad_() result = crop(img, (10 , 10 , 50 , 50 )) result.sum().backward() display(img.grad)
这是正确的梯度,在裁剪区域内是1(白色),在未使用的区域内是0(黑色)。
测试Python自定义运算符 使用torch.library.opcheck
来测试自定义运算符是否正确注册。这不会测试梯度是否在数学上正确,请单独编写测试(手动测试或使用torch.autograd.gradcheck
)。
要使用opcheck
,请传入一组示例输入用于测试。如果你的运算符支持训练,那么示例应该包括需要计算梯度的Tensor。如果你的运算符支持多个设备,那么示例应该包括来自每个设备的Tensor。
examples = [ [torch.randn(3 , 64 , 64 ), [0 , 0 , 10 , 10 ]], [torch.randn(3 , 91 , 91 , requires_grad=True ), [10 , 0 , 20 , 10 ]], [torch.randn(3 , 60 , 60 , dtype=torch.double), [3 , 4 , 32 , 20 ]], [torch.randn(3 , 512 , 512 , requires_grad=True , dtype=torch.double), [3 , 4 , 32 , 45 ]], ]for example in examples: torch.library.opcheck(crop, example)
可变的Python自定义运算符 你也可以将一个会修改其输入的Python函数封装为自定义运算符。修改输入的函数很常见,因为这是许多low-level kernel编写的方式;例如,计算sin的kernel可能会修改输入,并将输出张量赋值为input.sin()
。
我们将使用numpy.sin
来演示一个可变的Python自定义运算符的示例。
import numpy as np@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu") def numpy_sin (input: torch.Tensor, output: torch.Tensor) -> None : assert input.device == output.device assert input.device.type == "cpu" input_np = input.numpy() output_np = output.numpy() np.sin(input_np, out=output_np)
由于该运算符没有返回值,因此不需要注册FakeTensor kernel(元kernel)就可以在torch.compile
中正常工作。
@torch.compile(fullgraph=True) def f (x) : out = torch.empty(3 ) numpy_sin(x, out) return out x = torch.randn(3 ) y = f(x)assert torch.allclose(y, x.sin())
这里是一次opcheck运行的结果,告诉我们确实正确注册了该kernel。如果我们忘记添加输出到mutates_args
,例如,opcheck
将会报错。
总结 在本教程中,我们学习了如何使用torch.library.custom_op
创建一个与PyTorch子系统如torch.compile
和autograd
协同工作的Python自定义运算符。
本教程提供了自定义运算符的基本介绍。更多详细信息,请参阅:
torch.library文档: https://pytorch.org/docs/stable/library.html 自定义运算符手册: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual