cupyx.scipy.interpolate.RegularGridInterpolator#

class cupyx.scipy.interpolate.RegularGridInterpolator(points, values, method='linear', bounds_error=True, fill_value=nan)[源码]#

在任意维度的规则或直线网格上进行插值。

数据必须在直线网格上定义;也就是说,是一个等间距或不等间距的矩形网格。支持线性和最近邻插值。设置插值对象后,可以在每次评估时选择插值方法。

参数:
  • points (tuple of ndarray of float, with shapes (m1, ), ..., (mn, )) – 定义 n 维规则网格的点。每个维度中的点(即 points 元组的每个元素)必须严格升序或降序排列。

  • values (ndarray, shape (m1, ..., mn, ...)) – n 维规则网格上的数据。可接受复数数据。

  • method (str, optional) – 要执行的插值方法。支持 "linear", "nearest", "slinear", "cubic", "quintic" 和 "pchip"。此参数将成为对象的 __call__ 方法的默认值。默认为 "linear"。

  • bounds_error (bool, optional) – 如果为 True,当请求插值点超出输入数据域时,将引发 ValueError。如果为 False,则使用 fill_value。默认为 True。

  • fill_value (float or None, optional) – 用于插值域外点的值。如果为 None,则域外的值将被外插。默认为 cp.nan

注意事项

与 scipy 的 LinearNDInterpolatorNearestNDInterpolator 不同,该类通过利用规则网格结构避免了昂贵的输入数据三角剖分。

换句话说,该类假设数据定义在直线网格上。

如果输入数据的维度单位不一致且相差许多数量级,则插值可能存在数值伪影。考虑在插值之前重新缩放数据。

示例

在 3-D 网格点上评估函数

作为第一个例子,我们在 3-D 网格点上评估一个简单的示例函数

>>> from cupyx.scipy.interpolate import RegularGridInterpolator
>>> import cupy as cp
>>> def f(x, y, z):
...     return 2 * x**3 + 3 * y**2 - z
>>> x = cp.linspace(1, 4, 11)
>>> y = cp.linspace(4, 7, 22)
>>> z = cp.linspace(7, 9, 33)
>>> xg, yg ,zg = cp.meshgrid(x, y, z, indexing='ij', sparse=True)
>>> data = f(xg, yg, zg)

data 现在是一个 3-D 数组,其 data[i, j, k] = f(x[i], y[j], z[k])。接下来,从这些数据定义一个插值函数

>>> interp = RegularGridInterpolator((x, y, z), data)

在两个点 (x,y,z) = (2.1, 6.2, 8.3)(3.3, 5.2, 7.1) 处评估插值函数

>>> pts = cp.array([[2.1, 6.2, 8.3],
...                 [3.3, 5.2, 7.1]])
>>> interp(pts)
array([ 125.80469388,  146.30069388])

这确实是以下结果的近似值

>>> f(2.1, 6.2, 8.3), f(3.3, 5.2, 7.1)
(125.54200000000002, 145.894)

插值和外插 2D 数据集

作为第二个例子,我们插值和外插一个 2D 数据集

>>> x, y = cp.array([-2, 0, 4]), cp.array([-2, 0, 2, 5])
>>> def ff(x, y):
...     return x**2 + y**2
>>> xg, yg = cp.meshgrid(x, y, indexing='ij')
>>> data = ff(xg, yg)
>>> interp = RegularGridInterpolator((x, y), data,
...                                  bounds_error=False, fill_value=None)
>>> import matplotlib.pyplot as plt
>>> fig = plt.figure()
>>> ax = fig.add_subplot(projection='3d')
>>> ax.scatter(xg.ravel().get(), yg.ravel().get(), data.ravel().get(),
...            s=60, c='k', label='data')

在更精细的网格上评估并绘制插值器

>>> xx = cp.linspace(-4, 9, 31)
>>> yy = cp.linspace(-4, 9, 31)
>>> X, Y = cp.meshgrid(xx, yy, indexing='ij')
>>> # interpolator
>>> ax.plot_wireframe(X.get(), Y.get(), interp((X, Y)).get(),
                      rstride=3, cstride=3, alpha=0.4, color='m',
                      label='linear interp')
>>> # ground truth
>>> ax.plot_wireframe(X.get(), Y.get(), ff(X, Y).get(),
                      rstride=3, cstride=3,
...                   alpha=0.4, label='ground truth')
>>> plt.legend()
>>> plt.show()

另请参阅

interpn

一个包装 RegularGridInterpolator 的便利函数

scipy.ndimage.map_coordinates

等间距网格上的插值(适用于例如 N 维图像重采样)

参考文献

[1] Python 库 regulargrid 作者 Johannes Buchner,参见

https://pypi.python.org/pypi/regulargrid/

[2] 维基百科,“三线性插值”,

https://en.wikipedia.org/wiki/Trilinear_interpolation

[3] Weiser, Alan, 和 Sergio E. Zarantonello。“关于多维分段线性和多线性表格插值的一点注记。” MATH. COMPUT. 50.181 (1988): 189-196. https://www.ams.org/journals/mcom/1988-50-181/S0025-5718-1988-0917826-0/S0025-5718-1988-0917826-0.pdf

“多维线性和多线性表格插值”。MATH. COMPUT. 50.181 (1988): 189-196. https://www.ams.org/journals/mcom/1988-50-181/S0025-5718-1988-0917826-0/S0025-5718-1988-0917826-0.pdf

方法

__call__(xi, method=None)[源码]#

在给定坐标处进行插值。

参数:
  • xi (cupy.ndarray of shape (..., ndim)) – 要在其中评估插值器的坐标。

  • method (str, optional) – 要执行的插值方法。支持 "linear" 和 "nearest"。默认为创建插值器时选择的方法。

返回:

values_x – 在 xi 处的插值结果。关于 xi.ndim == 1 时的行为请参阅注意事项。

返回类型:

cupy.ndarray, shape xi.shape[:-1] + values.shape[ndim:]

注意事项

如果 xi.ndim == 1,则在返回数组 values_x 的第 0 个位置插入一个新的轴,使其形状变为 (1,) + values.shape[ndim:]

示例

这里我们定义一个简单函数的最近邻插值器

>>> import cupy as cp
>>> x, y = cp.array([0, 1, 2]), cp.array([1, 3, 7])
>>> def f(x, y):
...     return x**2 + y**2
>>> data = f(*cp.meshgrid(x, y, indexing='ij', sparse=True))
>>> from cupyx.scipy.interpolate import RegularGridInterpolator
>>> interp = RegularGridInterpolator((x, y), data, method='nearest')

根据构造,插值器使用最近邻插值

>>> interp([[1.5, 1.3], [0.3, 4.5]])
array([2., 9.])

然而,我们可以通过覆盖 method 参数来评估线性插值

>>> interp([[1.5, 1.3], [0.3, 4.5]], method='linear')
array([ 4.7, 24.3])
__eq__(value, /)#

返回 self==value。

__ne__(value, /)#

返回 self!=value。

__lt__(value, /)#

返回 self<value。

__le__(value, /)#

返回 self<=value。

__gt__(value, /)#

返回 self>value。

__ge__(value, /)#

返回 self>=value。