관리-도구
편집 파일: linalg.py
from __future__ import annotations from ._dtypes import ( _floating_dtypes, _numeric_dtypes, float32, float64, complex64, complex128 ) from ._manipulation_functions import reshape from ._array_object import Array from ..core.numeric import normalize_axis_tuple from typing import TYPE_CHECKING if TYPE_CHECKING: from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype from typing import NamedTuple import numpy.linalg import numpy as np class EighResult(NamedTuple): eigenvalues: Array eigenvectors: Array class QRResult(NamedTuple): Q: Array R: Array class SlogdetResult(NamedTuple): sign: Array logabsdet: Array class SVDResult(NamedTuple): U: Array S: Array Vh: Array # Note: the inclusion of the upper keyword is different from # np.linalg.cholesky, which does not have it. def cholesky(x: Array, /, *, upper: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.cholesky <numpy.linalg.cholesky>`. See its docstring for more information. """ # Note: the restriction to floating-point dtypes only is different from # np.linalg.cholesky. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in cholesky') L = np.linalg.cholesky(x._array) if upper: return Array._new(L).mT return Array._new(L) # Note: cross is the numpy top-level namespace, not np.linalg def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: """ Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`. See its docstring for more information. 
""" if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in cross') # Note: this is different from np.cross(), which broadcasts if x1.shape != x2.shape: raise ValueError('x1 and x2 must have the same shape') if x1.ndim == 0: raise ValueError('cross() requires arrays of dimension at least 1') # Note: this is different from np.cross(), which allows dimension 2 if x1.shape[axis] != 3: raise ValueError('cross() dimension must equal 3') return Array._new(np.cross(x1._array, x2._array, axis=axis)) def det(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`. See its docstring for more information. """ # Note: the restriction to floating-point dtypes only is different from # np.linalg.det. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in det') return Array._new(np.linalg.det(x._array)) # Note: diagonal is the numpy top-level namespace, not np.linalg def diagonal(x: Array, /, *, offset: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`. See its docstring for more information. """ # Note: diagonal always operates on the last two axes, whereas np.diagonal # operates on the first two axes by default return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`. See its docstring for more information. """ # Note: the restriction to floating-point dtypes only is different from # np.linalg.eigh. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in eigh') # Note: the return type here is a namedtuple, which is different from # np.eigh, which only returns a tuple. 
return EighResult(*map(Array._new, np.linalg.eigh(x._array))) def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`. See its docstring for more information. """ # Note: the restriction to floating-point dtypes only is different from # np.linalg.eigvalsh. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in eigvalsh') return Array._new(np.linalg.eigvalsh(x._array)) def inv(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`. See its docstring for more information. """ # Note: the restriction to floating-point dtypes only is different from # np.linalg.inv. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in inv') return Array._new(np.linalg.inv(x._array)) # Note: matmul is the numpy top-level namespace but not in np.linalg def matmul(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`. See its docstring for more information. """ # Note: the restriction to numeric dtypes only is different from # np.matmul. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in matmul') return Array._new(np.matmul(x1._array, x2._array)) # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point # literals. def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`. See its docstring for more information. """ # Note: the restriction to floating-point dtypes only is different from # np.linalg.norm. 
if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in matrix_norm') return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)) def matrix_power(x: Array, n: int, /) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`. See its docstring for more information. """ # Note: the restriction to floating-point dtypes only is different from # np.linalg.matrix_power. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power') # np.matrix_power already checks if n is an integer return Array._new(np.linalg.matrix_power(x._array, n)) # Note: the keyword argument name rtol is different from np.linalg.matrix_rank def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`. See its docstring for more information. """ # Note: this is different from np.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") S = np.linalg.svd(x._array, compute_uv=False) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps else: if isinstance(rtol, Array): rtol = rtol._array # Note: this is different from np.linalg.matrix_rank, which does not multiply # the tolerance by the largest singular value. tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis] return Array._new(np.count_nonzero(S > tol, axis=-1)) # Note: this function is new in the array API spec. Unlike transpose, it only # transposes the last two axes. 
def matrix_transpose(x: Array, /) -> Array:
    """
    Swap the last two axes of ``x``.

    Raises ``ValueError`` if ``x`` has fewer than 2 dimensions.
    """
    if x.ndim < 2:
        raise ValueError("x must be at least 2-dimensional for matrix_transpose")
    return Array._new(np.swapaxes(x._array, -1, -2))


# Note: outer is the numpy top-level namespace, not np.linalg
def outer(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.

    See its docstring for more information.

    Raises ``TypeError`` if either input is not of a numeric dtype, and
    ``ValueError`` if either input is not 1-dimensional.
    """
    # Note: the restriction to numeric dtypes only is different from
    # np.outer.
    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
        raise TypeError('Only numeric dtypes are allowed in outer')

    # Note: the restriction to only 1-dim arrays is different from np.outer
    if x1.ndim != 1 or x2.ndim != 1:
        raise ValueError('The input arrays to outer must be 1-dimensional')

    return Array._new(np.outer(x1._array, x2._array))


# Note: the keyword argument name rtol is different from np.linalg.pinv
def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.

    See its docstring for more information.

    When ``rtol`` is None it defaults to ``max(M, N) * eps``, where ``M, N``
    are the last two dimensions of ``x`` and ``eps`` is the machine epsilon
    of ``x.dtype``. The value is forwarded to ``np.linalg.pinv`` as ``rcond``.

    Raises ``TypeError`` if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.pinv.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in pinv')

    # Note: this is different from np.linalg.pinv, which does not multiply the
    # default tolerance by max(M, N).
    if rtol is None:
        rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
    return Array._new(np.linalg.pinv(x._array, rcond=rtol))


def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.qr <numpy.linalg.qr>`.

    See its docstring for more information.

    Raises ``TypeError`` if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.qr.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in qr')

    # Note: the return type here is a namedtuple, which is different from
    # np.linalg.qr, which only returns a tuple.
    return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode)))


def slogdet(x: Array, /) -> SlogdetResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.slogdet <numpy.linalg.slogdet>`.

    See its docstring for more information.

    Raises ``TypeError`` if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.slogdet.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in slogdet')

    # Note: the return type here is a namedtuple, which is different from
    # np.linalg.slogdet, which only returns a tuple.
    return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array)))


# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.

# To workaround this, the below is the code from np.linalg.solve except
# only calling solve1 in the exactly 1D case.
def _solve(a, b):
    # Private helper for solve(): operates on the underlying NumPy arrays and
    # picks the vector gufunc (solve1) only when b is exactly 1-dimensional.
    from ..linalg.linalg import (_makearray, _assert_stacked_2d,
                                 _assert_stacked_square, _commonType,
                                 isComplexType, get_linalg_error_extobj,
                                 _raise_linalgerror_singular)
    from ..linalg import _umath_linalg

    a, _ = _makearray(a)
    _assert_stacked_2d(a)
    _assert_stacked_square(a)
    b, wrap = _makearray(b)
    t, result_t = _commonType(a, b)

    # This part is different from np.linalg.solve
    if b.ndim == 1:
        gufunc = _umath_linalg.solve1
    else:
        gufunc = _umath_linalg.solve

    # This does nothing currently but is left in because it will be relevant
    # when complex dtype support is added to the spec in 2022.
    signature = 'DD->D' if isComplexType(t) else 'dd->d'
    extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
    r = gufunc(a, b, signature=signature, extobj=extobj)

    return wrap(r.astype(result_t, copy=False))


def solve(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.solve <numpy.linalg.solve>`.

    See its docstring for more information.

    ``x2`` is treated as a vector only when it is exactly 1-dimensional.

    Raises ``TypeError`` if either input does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.solve.
    if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in solve')

    return Array._new(_solve(x1._array, x2._array))


def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.svd <numpy.linalg.svd>`.

    See its docstring for more information.

    Raises ``TypeError`` if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.svd.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in svd')

    # Note: the return type here is a namedtuple, which is different from
    # np.svd, which only returns a tuple.
    return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices)))


# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# np.linalg.svd(compute_uv=False).
def svdvals(x: Array, /) -> Array:
    """
    Return the singular values of ``x`` as a single array.

    Computed with ``np.linalg.svd(..., compute_uv=False)``.

    Raises ``TypeError`` if ``x`` does not have a floating-point dtype.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in svdvals')
    return Array._new(np.linalg.svd(x._array, compute_uv=False))


# Note: tensordot is the numpy top-level namespace but not in np.linalg

# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.tensordot <numpy.tensordot>`.

    See its docstring for more information.

    Raises ``TypeError`` if either input is not of a numeric dtype.
    """
    # Note: the restriction to numeric dtypes only is different from
    # np.tensordot.
    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
        raise TypeError('Only numeric dtypes are allowed in tensordot')

    return Array._new(np.tensordot(x1._array, x2._array, axes=axes))


# Note: trace is the numpy top-level namespace, not np.linalg
def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.

    See its docstring for more information.

    When ``dtype`` is None, float32 input is accumulated as float64 and
    complex64 input as complex128; other dtypes are left to ``np.trace``.

    Raises ``TypeError`` if ``x`` is not of a numeric dtype.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError('Only numeric dtypes are allowed in trace')

    # Note: trace() works the same as sum() and prod() (see
    # _statistical_functions.py)
    if dtype is None:
        if x.dtype == float32:
            dtype = float64
        elif x.dtype == complex64:
            dtype = complex128
    # Note: trace always operates on the last two axes, whereas np.trace
    # operates on the first two axes by default
    return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)))


# Note: vecdot is not in NumPy
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    """
    Compute the vector dot product of ``x1`` and ``x2`` along ``axis``.

    The inputs are broadcast against each other, and the result is the sum
    over ``axis`` of ``conj(x1) * x2``.

    Raises ``TypeError`` if either input is not of a numeric dtype, and
    ``ValueError`` if the inputs differ in size along ``axis``.
    """
    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
        raise TypeError('Only numeric dtypes are allowed in vecdot')
    ndim = max(x1.ndim, x2.ndim)
    # Left-pad the shapes with 1s so that axis indexes the same dimension of
    # both inputs after broadcasting.
    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
    if x1_shape[axis] != x2_shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
    x1_ = np.moveaxis(x1_, axis, -1)
    x2_ = np.moveaxis(x2_, axis, -1)

    # The first argument is conjugated, which only changes the result for
    # complex input. The contraction is a stacked (1, n) @ (n, 1) matmul.
    res = np.conj(x1_[..., None, :]) @ x2_[..., None]
    return Array._new(res[..., 0, 0])


# Note: the name here is different from norm(). The array API norm is split
# into matrix_norm and vector_norm().

# The type for ord should be Optional[Union[int, float, Literal[np.inf,
# -np.inf]]] but Literal does not support floating-point literals.
def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.

    See its docstring for more information.

    ``axis`` may be None (norm over all elements), a single int, or a tuple
    of ints (norm over those axes taken together).

    Raises ``TypeError`` if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.norm.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in norm')

    # np.linalg.norm switches to a matrix norm for a 2-tuple axis, and for
    # axis=None on 2-D input. To always get a vector norm, the axes being
    # reduced are first collapsed into one leading axis.
    data = x._array
    if axis is None:
        # Note: np.linalg.norm() doesn't handle 0-D arrays
        data = data.ravel()
        norm_axis = 0
    elif isinstance(axis, tuple):
        # Note: The axis argument supports any number of axes, whereas
        # np.linalg.norm() only supports a single axis for vector norm.
        reduced = normalize_axis_tuple(axis, x.ndim)
        kept = tuple(i for i in range(data.ndim) if i not in reduced)
        # Move the reduced axes to the front, then merge them into one.
        permuted = np.transpose(data, axis + kept)
        merged_size = np.prod([data.shape[i] for i in axis], dtype=int)
        kept_sizes = [data.shape[i] for i in kept]
        data = permuted.reshape((merged_size, *kept_sizes))
        norm_axis = 0
    else:
        norm_axis = axis

    res = Array._new(np.linalg.norm(data, axis=norm_axis, ord=ord))

    if not keepdims:
        return res

    # np.linalg.norm(keepdims=True) cannot be used because the input was
    # reshaped above, so the size-1 axes are restored by hand.
    out_shape = list(x.shape)
    collapsed = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
    for i in collapsed:
        out_shape[i] = 1
    return reshape(res, tuple(out_shape))


__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']