cupyx.GeneralizedUFunc#
- class cupyx.GeneralizedUFunc(func, signature, **kwargs)[source]#
通过使用指定的签名封装用户提供的函数,创建一个广义通用函数 (Generalized Universal Function)。
signature
决定函数是消耗还是生成核心维度。给定输入数组 (*args
) 中剩余的维度被视为循环维度,并且需要能够自然地相互广播。- 参数:
func (callable) – 在输入数组 (
*args
) 上调用,形式为func(*args, **kwargs)
,返回一个数组或数组的元组。如果提供了多个维度不匹配的参数,则此函数应能像 NumPy 通用函数那样,对位置参数的轴进行向量化(广播)。signature (string) – 指定
func
消耗和生成的核维度。遵循 numpy.gufunc signature 的规范。supports_batched (bool, optional) – 如果被封装的函数支持传入包含循环维度和核心维度的完整输入数组。默认为 False。维度将在 GUFunc 处理代码中进行迭代。
supports_out (bool, optional) – 如果被封装的函数支持 out 作为其一个 kwargs。默认为 False。
signatures (list of tuple of str) – 包含形如 'ii->i' 的字符串,其中 i 是 dtype 的字符表示。列表的每个元素都是一个元组,包含该字符串和一个替代 func 的备选函数,当函数的输入可以按照此函数描述的方式进行转换时执行。
name (str, optional) – GUFunc 对象的名称。如果未指定,则使用
func
的名称。doc (str, optional) – GUFunc 对象的文档字符串。如果未指定,则使用
func.__doc__
。
方法
- __call__(*args, **kwargs)[source]#
应用一个广义通用函数。
- 参数:
args – 输入参数。每个参数可以是一个
cupy.ndarray
对象或一个标量。输出参数可以省略,或者由out
参数指定。axes (List of tuples of int, optional) – 一个元组列表,包含广义通用函数应该操作的轴的索引。例如,对于适合矩阵乘法的签名
'(i,j),(j,k)->(i,k)'
,基本元素是二维矩阵,它们被存储在每个参数的最后两个轴中。相应的 axes 关键字将是[(-2, -1), (-2, -1), (-2, -1)]
。为简单起见,对于对一维数组(向量)操作的广义通用函数,可以接受单个整数而不是单元素元组;对于所有输出都是标量的广义通用函数,可以省略输出元组。axis (int, optional) – 广义通用函数应该操作的单个轴。这是对只在单个共享核心维度上操作的通用函数的简化,相当于为每个单核心维度参数传入 (axis,) 作为 axes 条目,为所有其他参数传入
()
。例如,对于签名'(i),(i)->()'
,它等同于传入axes=[(axis,), (axis,), ()]
。keepdims (bool, optional) – 如果设置为 True,则被约简的轴将作为大小为一的维度保留在结果中,以便结果能正确地与输入进行广播。此选项仅适用于核心维度数量相同且输出没有核心维度的广义通用函数,即签名如
'(i),(i)->()'
或'(m,m)->()'
的函数。如果使用此选项,输出中维度的位置可以通过 axes 和 axis 控制。casting (str, optional) – 提供关于允许何种类型转换的策略。默认为
'same_kind'
。dtype (dtype, optional) – 覆盖计算和输出数组的 dtype。类似于 signature。
signature (str or tuple of dtype, optional) – 数据类型、数据类型元组或指示通用函数输入输出类型的特殊签名字符串。此参数允许您为函数提供一个特定的签名,如果该签名已在
__init__
方法的signatures
kwargs 中注册,则使用该签名。如果通用函数不存在指定的循环,则会引发 TypeError。通常,系统会通过比较输入类型与可用类型,并搜索所有输入都可以安全转换的数据类型循环来自动找到合适的循环。此关键字参数允许您绕过该搜索并选择特定的循环。order (str, optional) – 指定输出数组的内存布局。默认为
'K'
。``’C’`` 表示输出应为 C 连续,'F'
表示 F 连续,'A'
表示如果输入为 F 连续且非 C 连续则为 F 连续,否则为 C 连续,'K'
表示尽可能匹配输入的元素顺序。out (cupy.ndarray) – 输出数组。默认情况下,输出到新数组。
- 返回值:
输出数组或输出数组的元组。
- __eq__(value, /)#
返回 self==value。
- __ne__(value, /)#
返回 self!=value。
- __lt__(value, /)#
返回 self<value。
- __le__(value, /)#
返回 self<=value。
- __gt__(value, /)#
返回 self>value。
- __ge__(value, /)#
返回 self>=value。