cupy.get_array_module#

cupy.get_array_module(*args)[source]#

返回参数的数组模块。

此函数用于实现 CPU/GPU 通用代码。如果至少有一个参数是 cupy.ndarray 对象,则返回 cupy 模块。

参数:

args – 用于确定应使用 NumPy 还是 CuPy 的值。

返回:

根据参数的类型返回 cupynumpy

返回类型:

模块

示例

一个 NumPy/CuPy 通用函数可以按如下方式编写

>>> def softplus(x):
...     xp = cupy.get_array_module(x)
...     return xp.maximum(0, x) + xp.log1p(xp.exp(-abs(x)))