社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

【PyTorch 奇淫技巧】如何在PyTorch中创建和使用Python自定义操作符

极市平台 • 7 月前 • 455 次点击  
↑ 点击蓝字 关注极市平台
作者丨GiantPandaCV
来源丨GiantPandaCV
编辑丨极市平台

极市导读

 

关于如何在PyTorch中创建和使用Python自定义操作符(Custom Operators)的教程 >>感谢大家对极市直播的支持,由于讲者老师身体不适,原定今晚的直播延迟到下周,届时我们会重新发出直播预告~

前言

在vllm里面看到flash attention包了一层@torch.library.custom_op装饰器(https://github.com/vllm-project/vllm/pull/7536),查阅了一下资料,发现这个是torch 2.4之后的新feature,防止打算torch compile的graph,翻译一下官方教程稍微了解一下这个用法。来源:https://pytorch.org/tutorials/advanced/python_custom_ops.html

Python Custom Operators 教程

这个教程介绍了Python自定义运算符的主题。它列出了我们将从这一教程中学习到的内容,包括如何将用Python编写的自定义运算符与PyTorch集成,以及如何使用torch.library.opcheck来测试自定义运算符。所需的先决条件是安装了PyTorch 2.4或更高版本。

PyTorch提供了大量可以在Tensor上运行的运算符(例如torch.add、torch.sum等)。但是,您可能希望在PyTorch中使用一个新的自定义运算符,可能是由第三方库编写的。本教程展示了如何封装Python函数,使它们的行为类似于PyTorch原生运算符。创建PyTorch中的自定义运算符的原因可能包括:

  • 将任意Python函数视为不透明的可调用对象,与torch.compile相对应(即防止torch.compile跟踪进入函数)。
  • 为任意Python函数添加训练支持。

请注意,如果您的操作可以表示为现有PyTorch运算符的组合,那么通常就不需要使用自定义运算符(例如,支持自动微分的运算应该可以直接工作)。

例子:将PIL库的crop功能封装为一个自定义运算符

假设我们在使用PIL的crop操作

import torch  
from torchvision.transforms.functional import to_pil_image, pil_to_tensor  
import PIL  
import IPython  
import matplotlib.pyplot as plt  
  
def crop(pic, box):  
    img = to_pil_image(pic.cpu())  
    cropped_img = img.crop(box)  
    return pil_to_tensor(cropped_img).to(pic.device) / 255.  
  
def display(img):  
    plt.imshow(img.numpy().transpose((1, 2, 0)))  
  
img = torch.ones(3, 64, 64)  
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)  
display(img)  
cropped_img = crop(img, (10, 10, 50, 50))  
display(cropped_img)  

crop功能无法被torch.compile有效地开箱即用处理:torch.compile在无法处理的函数上会引发"图中断"(https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks),而图中断会导致性能下降。以下代码通过引发错误来演示这一点(如果发生图中断,torch.compile(with fullgraph=True)会引发错误)。

@torch.compile(fullgraph=True)  
def f(img):  
    return crop(img, (10, 10, 50, 50))  
  
# The following raises an error. Uncomment the line to see it.  
# cropped_img = f(img)  

为了能在torch.compile中使用crop作为黑盒操作,我们需要做两件事:

  • 将该函数封装为一个PyTorch自定义运算符。
  • 为该运算符添加"FakeTensor kernel"(又称"meta kernel")。给定输入Tensor的元数据(例如形状),此函数说明如何计算输出Tensor的元数据。
from typing import Sequence  
  
# Use torch.library.custom_op to define a new custom operator.  
# If your operator mutates any input Tensors, their names must be specified  
# in the ``mutates_args`` argument.  
@torch.library.custom_op("mylib::crop", mutates_args=())  
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:  
    img = to_pil_image(pic.cpu())  
    cropped_img = img.crop(box)  
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)  
  
# Use register_fake to add a ``FakeTensor`` kernel for the operator  
@crop.register_fake  
def _(pic, box):  
    channels = pic.shape[0]  
    x0, y0, x1, y1 = box  
    return pic.new_empty(channels, y1 - y0, x1 - x0)  

做了上述操作之后,crop现在可以在不产生图中断的情况下正常工作了。

@torch.compile(fullgraph=True)  
def f(img):  
    return crop(img, (10, 10, 50, 50))  
  
cropped_img = f(img)  
display(img)  
display(cropped_img)  

为crop添加训练支持

使用torch.library.register_autograd为运算符添加训练支持。相比直接使用torch.autograd.Function,优先使用这种方式;因为autograd.Function与PyTorch运算符注册API组合使用时,可能会在与torch.compile组合时导致无声的不正确性。crop的梯度公式本质上是PIL.paste(我们把推导留作读者练习)。让我们首先将paste封装为一个自定义运算符:




    
@torch.library.custom_op("mylib::paste", mutates_args=())  
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:  
    assert im1.device == im2.device  
    assert im1.dtype == im2.dtype  
    im1_pil = to_pil_image(im1.cpu())  
    im2_pil = to_pil_image(im2.cpu())  
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)  
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)  
  
@paste.register_fake  
def _(im1, im2, coord):  
    assert im1.device == im2.device  
    assert im1.dtype == im2.dtype  
    return torch.empty_like(im1)  

现在让我们使用register_autograd来为crop指定梯度公式:

def backward(ctx, grad_output):  
    grad_input = grad_output.new_zeros(ctx.pic_shape)  
    grad_input = paste(grad_input, grad_output, ctx.coords)  
    return grad_input, None  
  
def setup_context(ctx, inputs, output):  
    pic, box = inputs  
    ctx.coords = box[:2]  
    ctx.pic_shape = pic.shape  
  
crop.register_autograd(backward, setup_context=setup_context)  

注意,backward必须是由PyTorch可理解的运算符组成,这也是我们将paste封装为自定义运算符而不直接使用PIL的paste的原因。

img = img.requires_grad_()  
result = crop(img, (10, 10, 50, 50))  
result.sum().backward()  
display(img.grad)  

这是正确的梯度,在裁剪区域内是1(白色),在未使用的区域内是0(黑色)。

测试Python自定义运算符

使用torch.library.opcheck来测试自定义运算符是否正确注册。这不会测试梯度是否在数学上正确,请单独编写测试(手动测试或使用torch.autograd.gradcheck)。要使用opcheck,请传入一组示例输入用于测试。如果你的运算符支持训练,那么示例应该包括需要计算梯度的Tensor。如果你的运算符支持多个设备,那么示例应该包括来自每个设备的Tensor。

examples = [  
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],  
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],  
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],  
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],  
]  
  
for example in examples:  
    torch.library.opcheck(crop, example)  

可变的Python自定义运算符

你也可以将一个会修改其输入的Python函数封装为自定义运算符。修改输入的函数很常见,因为这是许多low-level kernel编写的方式;例如,计算sin的kernel可能会修改输入,并将输出张量赋值为input.sin()。我们将使用numpy.sin来演示一个可变的Python自定义运算符的示例。

import numpy as np  
  
@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")  
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:  
    assert input.device == output.device  
    assert input.device.type == "cpu"  
    input_np = input.numpy()  
    output_np = output.numpy()  
    np.sin(input_np, out=output_np)  

由于该运算符没有返回值,因此不需要注册FakeTensor kernel(元kernel)就可以在torch.compile中正常工作。

@torch.compile(fullgraph=True)  
def f(x):  
    out = torch.empty(3)  
    numpy_sin(x, out)  
    return out  
  
x = torch.randn(3)  
y = f(x)  
assert torch.allclose(y, x.sin())  

这里是一次opcheck运行的结果,告诉我们确实正确注册了该kernel。如果我们忘记添加输出到mutates_args,例如,opcheck将会报错。

总结

在本教程中,我们学习了如何使用torch.library.custom_op创建一个与PyTorch子系统如torch.compileautograd协同工作的Python自定义运算符。

本教程提供了自定义运算符的基本介绍。更多详细信息,请参阅:

  • torch.library文档: https://pytorch.org/docs/stable/library.html
  • 自定义运算符手册: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual


公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/173610
 
455 次点击