cupy.diag_indices#

cupy.diag_indices(n, ndim=2)[source]#

返回访问数组主对角线的索引。

返回一个索引元组,可用于访问具有 ndim >= 2 维且形状为 (n, n, …, n) 的数组的主对角线。

参数:
  • n (int) – 要返回索引的数组沿每个维度的大小。

  • ndim (int) – 维度数。默认为 2

示例

创建一个索引集合以访问 (4, 4) 数组的对角线

>>> di = cupy.diag_indices(4)
>>> di
(array([0, 1, 2, 3]), array([0, 1, 2, 3]))
>>> a = cupy.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
>>> a[di] = 100
>>> a
array([[100,   1,   2,   3],
       [  4, 100,   6,   7],
       [  8,   9, 100,  11],
       [ 12,  13,  14, 100]])

创建索引以操作 3-D 数组

>>> d3 = cupy.diag_indices(2, 3)
>>> d3
(array([0, 1]), array([0, 1]), array([0, 1]))

并用它将一个全零数组的对角线设置为 1

>>> a = cupy.zeros((2, 2, 2), dtype=int)
>>> a[d3] = 1
>>> a
array([[[1, 0],
        [0, 0]],

       [[0, 0],
        [0, 1]]])

另请参阅

numpy.diag_indices()