cupy.get_array_module#
- cupy.get_array_module(*args)[source]#
返回参数的数组模块。
此函数用于实现 CPU/GPU 通用代码。如果至少有一个参数是
cupy.ndarray
对象,则返回cupy
模块。示例
一个 NumPy/CuPy 通用函数可以按如下方式编写
>>> def softplus(x): ... xp = cupy.get_array_module(x) ... return xp.maximum(0, x) + xp.log1p(xp.exp(-abs(x)))