互操作性#
CuPy 可以与其他库协同使用。
NumPy#
cupy.ndarray
实现了 __array_ufunc__
接口(详情请参阅 NEP 13 — Ufuncs 重载机制)。这使得 NumPy 的 ufuncs 可以直接在 CuPy 数组上操作。__array_ufunc__
功能需要 NumPy 1.13 或更高版本。
import cupy
import numpy
arr = cupy.random.randn(1, 2, 3, 4).astype(cupy.float32)
result = numpy.sum(arr)
print(type(result)) # => <class 'cupy._core.core.ndarray'>
cupy.ndarray
还实现了 __array_function__
接口(详情请参阅 NEP 18 — NumPy 高级数组函数的调度机制)。这使得使用 NumPy 的代码可以直接在 CuPy 数组上操作。__array_function__
功能需要 NumPy 1.16 或更高版本;从 NumPy 1.17 开始,__array_function__
默认启用。
Numba#
Numba 是一个支持 NumPy 的 Python JIT 编译器。
cupy.ndarray
实现了 __cuda_array_interface__
,这是与 Numba v0.39.0 或更高版本兼容的 CUDA 数组交换接口(详情请参阅 CUDA 数组接口)。这意味着您可以将 CuPy 数组传递给使用 Numba JIT 编译的内核。以下是摘自 numba/numba#2860 的一个简单示例代码:
import cupy
from numba import cuda
@cuda.jit
def add(x, y, out):
start = cuda.grid(1)
stride = cuda.gridsize(1)
for i in range(start, x.shape[0], stride):
out[i] = x[i] + y[i]
a = cupy.arange(10)
b = a * 2
out = cupy.zeros_like(a)
print(out) # => [0 0 0 0 0 0 0 0 0 0]
add[1, 32](a, b, out)
print(out) # => [ 0 3 6 9 12 15 18 21 24 27]
此外,cupy.asarray()
支持从 Numba CUDA 数组到 CuPy 数组的零拷贝转换。
import numpy
import numba
import cupy
x = numpy.arange(10) # type: numpy.ndarray
x_numba = numba.cuda.to_device(x) # type: numba.cuda.cudadrv.devicearray.DeviceNDArray
x_cupy = cupy.asarray(x_numba) # type: cupy.ndarray
警告
__cuda_array_interface__
规定对象的生命周期必须由用户管理,因此如果在被消费者库使用时销毁导出的对象,则会产生未定义行为。
注意
CuPy 使用两个环境变量控制交换行为:CUPY_CUDA_ARRAY_INTERFACE_SYNC
和 CUPY_CUDA_ARRAY_INTERFACE_EXPORT_VERSION
。
mpi4py#
MPI for Python (mpi4py) 是消息传递接口 (MPI) 库的 Python 封装。
MPI 是高性能进程间通信最广泛使用的标准。最近,包括 MPICH、Open MPI 和 MVAPICH 在内的多家 MPI 供应商已将其支持范围扩展到 MPI-3.1 标准之外,以实现“CUDA 感知”;也就是说,直接将 CUDA 设备指针传递给 MPI 调用,从而避免主机和设备之间显式的数据移动。
通过 CuPy 中实现的 __cuda_array_interface__
(如上所述)和 DLPack
数据交换协议(详见下文 DLPack),mpi4py 现在提供(实验性)支持,用于将 CuPy 数组传递给 MPI 调用,前提是 mpi4py 是针对 CUDA 感知的 MPI 实现构建的。以下是摘自 mpi4py 教程 的一个简单示例代码:
# To run this script with N MPI processes, do
# mpiexec -n N python this_script.py
import cupy
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
# Allreduce
sendbuf = cupy.arange(10, dtype='i')
recvbuf = cupy.empty_like(sendbuf)
comm.Allreduce(sendbuf, recvbuf)
assert cupy.allclose(recvbuf, sendbuf*size)
此新功能自 mpi4py 3.1.0 版本添加。更多信息请参阅 mpi4py 网站。
PyTorch#
PyTorch 是一个机器学习框架,提供高性能、可微分的张量运算。
PyTorch 也支持 __cuda_array_interface__
,因此 CuPy 和 PyTorch 之间可以实现零拷贝数据交换,无需额外开销。唯一的注意事项是 PyTorch 默认创建 CPU 张量,这些张量没有定义 __cuda_array_interface__
属性,用户需要在交换之前确保张量已在 GPU 上。
>>> import cupy as cp
>>> import torch
>>>
>>> # convert a torch tensor to a cupy array
>>> a = torch.rand((4, 4), device='cuda')
>>> b = cp.asarray(a)
>>> b *= b
>>> b
array([[0.8215962 , 0.82399917, 0.65607935, 0.30354425],
[0.422695 , 0.8367199 , 0.00208597, 0.18545236],
[0.00226746, 0.46201342, 0.6833052 , 0.47549972],
[0.5208748 , 0.6059282 , 0.1909013 , 0.5148635 ]], dtype=float32)
>>> a
tensor([[0.8216, 0.8240, 0.6561, 0.3035],
[0.4227, 0.8367, 0.0021, 0.1855],
[0.0023, 0.4620, 0.6833, 0.4755],
[0.5209, 0.6059, 0.1909, 0.5149]], device='cuda:0')
>>> # check the underlying memory pointer is the same
>>> assert a.__cuda_array_interface__['data'][0] == b.__cuda_array_interface__['data'][0]
>>>
>>> # convert a cupy array to a torch tensor
>>> a = cp.arange(10)
>>> b = torch.as_tensor(a, device='cuda')
>>> b += 3
>>> b
tensor([ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], device='cuda:0')
>>> a
array([ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
>>> assert a.__cuda_array_interface__['data'][0] == b.__cuda_array_interface__['data'][0]
PyTorch 也支持通过 DLPack
进行零拷贝数据交换(详见下文 DLPack)
import cupy
import torch
# Create a PyTorch tensor.
tx1 = torch.randn(1, 2, 3, 4).cuda()
# Convert it into a CuPy array.
cx = cupy.from_dlpack(tx1)
# Convert it back to a PyTorch tensor.
tx2 = torch.from_dlpack(cx)
pytorch-pfn-extras 库提供了与 PyTorch 的附加集成功能,包括内存池共享和流共享
>>> import cupy
>>> import torch
>>> import pytorch_pfn_extras as ppe
>>>
>>> # Perform CuPy memory allocation using the PyTorch memory pool.
>>> ppe.cuda.use_torch_mempool_in_cupy()
>>> torch.cuda.memory_allocated()
0
>>> arr = cupy.arange(10)
>>> torch.cuda.memory_allocated()
512
>>>
>>> # Change the default stream in PyTorch and CuPy:
>>> stream = torch.cuda.Stream()
>>> with ppe.cuda.stream(stream):
... ...
在 PyTorch 中使用自定义内核#
借助 DLPack 协议,使用 CuPy 用户定义内核在 PyTorch 中实现函数变得非常简单。下面是使用 cupy.RawKernel
计算对数前向和后向传播的 PyTorch autograd 函数示例。
import cupy
import torch
cupy_custom_kernel_fwd = cupy.RawKernel(
r"""
extern "C" __global__
void cupy_custom_kernel_fwd(const float* x, float* y, int size) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < size)
y[tid] = log(x[tid]);
}
""",
"cupy_custom_kernel_fwd",
)
cupy_custom_kernel_bwd = cupy.RawKernel(
r"""
extern "C" __global__
void cupy_custom_kernel_bwd(const float* x, float* gy, float* gx, int size) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < size)
gx[tid] = gy[tid] / x[tid];
}
""",
"cupy_custom_kernel_bwd",
)
class CuPyLog(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.input = x
# Enforce contiguous arrays to simplify RawKernel indexing.
cupy_x = cupy.ascontiguousarray(cupy.from_dlpack(x.detach()))
cupy_y = cupy.empty(cupy_x.shape, dtype=cupy_x.dtype)
x_size = cupy_x.size
bs = 128
cupy_custom_kernel_fwd(
(bs,), ((x_size + bs - 1) // bs,), (cupy_x, cupy_y, x_size)
)
# the ownership of the device memory backing cupy_y is implicitly
# transferred to torch_y, so this operation is safe even after
# going out of scope of this function.
torch_y = torch.from_dlpack(cupy_y)
return torch_y
@staticmethod
def backward(ctx, grad_y):
# Enforce contiguous arrays to simplify RawKernel indexing.
cupy_input = cupy.from_dlpack(ctx.input.detach()).ravel()
cupy_grad_y = cupy.from_dlpack(grad_y.detach()).ravel()
cupy_grad_x = cupy.zeros(cupy_grad_y.shape, dtype=cupy_grad_y.dtype)
gy_size = cupy_grad_y.size
bs = 128
cupy_custom_kernel_bwd(
(bs,),
((gy_size + bs - 1) // bs,),
(cupy_input, cupy_grad_y, cupy_grad_x, gy_size),
)
# the ownership of the device memory backing cupy_grad_x is implicitly
# transferred to torch_y, so this operation is safe even after
# going out of scope of this function.
torch_grad_x = torch.from_dlpack(cupy_grad_x)
return torch_grad_x
注意
将 torch.Tensor
直接馈送给 cupy.from_dlpack()
仅在 CuPy v10+ 和 PyTorch 1.10+ 中新增的 DLPack 数据交换协议中受支持。对于早期版本,您需要使用 torch.utils.dlpack.to_dlpack()
包装 Tensor
,如上例所示。
RMM#
RMM (RAPIDS 内存管理器) 提供高度可配置的内存分配器。
RMM 提供了一个接口,允许 CuPy 从 RMM 内存池而不是 CuPy 自己的池中分配内存。设置方法很简单,例如
import cupy
import rmm
cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
有时,可能需要性能更高的分配器。RMM 提供了一个切换分配器的选项
import cupy
import rmm
rmm.reinitialize(pool_allocator=True) # can also set init pool size etc here
cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
有关 CuPy 内存管理的更多信息,请参阅内存管理。
DLPack#
DLPack 是一个张量结构规范,用于在不同框架之间共享张量。
CuPy 支持从 DLPack 数据结构导入以及导出到 DLPack 数据结构(cupy.from_dlpack()
和 cupy.ndarray.toDlpack()
)。
这是一个简单的例子
import cupy
# Create a CuPy array.
cx1 = cupy.random.randn(1, 2, 3, 4).astype(cupy.float32)
# Convert it into a DLPack tensor.
dx = cx1.toDlpack()
# Convert it back to a CuPy array.
cx2 = cupy.from_dlpack(dx)
TensorFlow 也支持 DLPack,因此 CuPy 和 TensorFlow 之间可以通过 DLPack 实现零拷贝数据交换
>>> import tensorflow as tf
>>> import cupy as cp
>>>
>>> # convert a TF tensor to a cupy array
>>> with tf.device('/GPU:0'):
... a = tf.random.uniform((10,))
...
>>> a
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.9672388 , 0.57568085, 0.53163004, 0.6536236 , 0.20479882,
0.84908986, 0.5852566 , 0.30355775, 0.1733712 , 0.9177849 ],
dtype=float32)>
>>> a.device
'/job:localhost/replica:0/task:0/device:GPU:0'
>>> cap = tf.experimental.dlpack.to_dlpack(a)
>>> b = cp.from_dlpack(cap)
>>> b *= 3
>>> b
array([1.4949363 , 0.60699713, 1.3276931 , 1.5781245 , 1.1914308 ,
2.3180873 , 1.9560868 , 1.3932796 , 1.9299742 , 2.5352407 ],
dtype=float32)
>>> a
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([1.4949363 , 0.60699713, 1.3276931 , 1.5781245 , 1.1914308 ,
2.3180873 , 1.9560868 , 1.3932796 , 1.9299742 , 2.5352407 ],
dtype=float32)>
>>>
>>> # convert a cupy array to a TF tensor
>>> a = cp.arange(10)
>>> cap = a.toDlpack()
>>> b = tf.experimental.dlpack.from_dlpack(cap)
>>> b.device
'/job:localhost/replica:0/task:0/device:GPU:0'
>>> b
<tf.Tensor: shape=(10,), dtype=int64, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])>
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
请注意,在 TensorFlow 中,所有张量都是不可变的,因此在后一种情况下,对 b
的任何更改都无法反映到 CuPy 数组 a
中。
请注意,根据 DLPack v0.5 的正确性要求,上述方法(隐式地)要求用户确保此类转换(导入和导出 CuPy 数组)必须在同一 CUDA/HIP 流上发生。如有疑问,例如通过调用 cupy.cuda.get_current_stream()
可以获取当前使用的 CuPy 流。请查阅其他框架的文档,了解如何访问和控制流。
DLPack 数据交换协议#
为了避免用户管理流和 DLPack 张量对象,DLPack 数据交换协议 提供了一种机制,将责任从用户转移到库。任何符合要求的对象(例如 cupy.ndarray
)必须实现一对方法 __dlpack__
和 __dlpack_device__
。函数 cupy.from_dlpack()
接受此类对象,并返回可在 CuPy 当前流上安全访问的 cupy.ndarray
。类似地,cupy.ndarray
可以通过任何符合要求的库的 from_dlpack()
函数导出。
注意
CuPy 使用 CUPY_DLPACK_EXPORT_VERSION
控制如何处理由 CUDA 托管内存支持的张量。
设备内存指针#
导入#
CuPy 提供了 UnownedMemory
API,用于与在其他库中分配的 GPU 设备内存进行互操作。
# Create a memory chunk from raw pointer and its size.
mem = cupy.cuda.UnownedMemory(140359025819648, 1024, owner=None)
# Wrap it as a MemoryPointer.
memptr = cupy.cuda.MemoryPointer(mem, offset=0)
# Create an ndarray view backed by the memory pointer.
arr = cupy.ndarray((16, 16), dtype=cupy.float32, memptr=memptr)
assert arr.nbytes <= arr.data.mem.size
请注意,在创建 ndarray
视图时,您有责任指定正确的形状、数据类型 (dtype)、跨度 (strides) 和顺序 (order),以使其适合内存块。
UnownedMemory
API 不管理内存分配的生命周期。您必须确保指针在使用 CuPy 时仍然有效。如果指针的生命周期由 Python 对象管理,您可以将其传递给 UnownedMemory
的 owner
参数,以保留对该对象的引用。
导出#
您可以将 CuPy 中分配的内存指针传递给其他库。
arr = cupy.arange(10)
print(arr.data.ptr, arr.nbytes) # => (140359025819648, 80)
CuPy 分配的内存会在 ndarray
(arr
) 被销毁时释放。在使用其他库期间,您必须保持 ndarray
实例处于活动状态。
CUDA 流指针#
导入#
CuPy 提供了 ExternalStream
API,用于与在其他库中创建的 CUDA 流进行互操作。
import torch
# Create a stream on PyTorch.
s = torch.cuda.Stream()
# Switch the current stream in PyTorch.
with torch.cuda.stream(s):
# Switch the current stream in CuPy, using the pointer of the stream created in PyTorch.
with cupy.cuda.ExternalStream(s.cuda_stream):
# This block runs on the same CUDA stream.
torch.arange(10, device='cuda')
cupy.arange(10)
ExternalStream
API 不管理流的生命周期。您必须确保流指针在使用 CuPy 时仍然有效。
您还需要确保 ExternalStream
对象在创建流的设备上使用。如果在创建 ExternalStream
时传递 device_id
参数,CuPy 可以为您验证这一点。
导出#
您可以将 CuPy 中创建的流传递给其他库。
s = cupy.cuda.Stream()
print(s.ptr, s.device_id) # => (93997451352336, 0)