Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 115 additions & 83 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -3215,6 +3215,22 @@
"lineCount": 3
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 17,
"endColumn": 36,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 17,
"endColumn": 31,
"lineCount": 1
}
},
{
"code": "reportUnknownParameterType",
"range": {
Expand Down Expand Up @@ -3242,16 +3258,104 @@
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 19,
"endColumn": 26,
"startColumn": 15,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 15,
"endColumn": 29,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 23,
"endColumn": 30,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 19,
"endColumn": 44,
"startColumn": 23,
"endColumn": 48,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 61,
"endColumn": 70,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 61,
"endColumn": 70,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 72,
"endColumn": 80,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 72,
"endColumn": 80,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 39,
"endColumn": 48,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 39,
"endColumn": 57,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 27,
"endColumn": 33,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 47,
"endColumn": 58,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 47,
"endColumn": 58,
"lineCount": 1
}
},
Expand All @@ -3264,7 +3368,7 @@
}
},
{
"code": "reportUnknownArgumentType",
"code": "reportArgumentType",
"range": {
"startColumn": 36,
"endColumn": 45,
Expand Down Expand Up @@ -4714,87 +4818,23 @@
}
},
{
"code": "reportUnknownParameterType",
"range": {
"startColumn": 12,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportUnknownParameterType",
"range": {
"startColumn": 24,
"endColumn": 34,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
"startColumn": 24,
"endColumn": 34,
"lineCount": 1
}
},
{
"code": "reportUnknownParameterType",
"range": {
"startColumn": 36,
"endColumn": 46,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
"startColumn": 36,
"endColumn": 46,
"lineCount": 1
}
},
{
"code": "reportUnknownParameterType",
"range": {
"startColumn": 48,
"endColumn": 58,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
"startColumn": 48,
"endColumn": 58,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 23,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"code": "reportReturnType",
"range": {
"startColumn": 40,
"endColumn": 55,
"lineCount": 1
"startColumn": 19,
"endColumn": 52,
"lineCount": 2
}
},
{
"code": "reportUnknownArgumentType",
"code": "reportArgumentType",
"range": {
"startColumn": 57,
"endColumn": 67,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"code": "reportArgumentType",
"range": {
"startColumn": 69,
"endColumn": 79,
Expand All @@ -4808,14 +4848,6 @@
"endColumn": 80,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 44,
"endColumn": 55,
"lineCount": 1
}
}
],
"./arraycontext/impl/pyopencl/taggable_cl_array.py": [
Expand Down
48 changes: 46 additions & 2 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import numpy as np
from typing_extensions import Self, override

from pytools import memoize_method

from arraycontext.container.traversal import (
rec_map_array_container,
rec_map_container,
Expand All @@ -62,7 +64,7 @@
if TYPE_CHECKING:
from collections.abc import Callable, Mapping

from numpy.typing import NDArray
from numpy.typing import DTypeLike, NDArray

import loopy as lp
import pyopencl as cl
Expand Down Expand Up @@ -263,12 +265,54 @@ def to_numpy(self, array: Array) -> np.ndarray:
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

@memoize_method
def _get_to_numpy_noncontiguous_copy_kernel(
Comment thread
kaushikcfd marked this conversation as resolved.
self, dtype: DTypeLike, ndim: int
) -> lp.TranslationUnit:
"""
Returns a translation unit containing a loopy kernel that:

- Accepts a PyOpenCL array ``inp`` with per-axis strides exposed as
``s0, s1, ..., s{ndim-1}``.
- Produces a contiguous, row-major (C-order) output array ``output`` of
the same shape, with elements copied from the corresponding
coordinates in ``input``.
"""

import loopy as lp

from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS

t_unit = lp.make_copy_kernel(
["c"] * ndim, [f"stride:s{i}" for i in range(ndim)]
)
t_unit = lp.add_dtypes(t_unit, {"input": dtype})
new_args = [
*t_unit.default_entrypoint.args,
*[lp.ValueArg(f"s{i}", dtype=np.uint64) for i in range(ndim)],
]
t_unit = t_unit.with_kernel(t_unit.default_entrypoint.copy(args=new_args))
t_unit = lp.set_options(t_unit, _DEFAULT_LOOPY_OPTIONS)
return t_unit

@override
def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
def _to_numpy(ary):
return ary.get(queue=self.queue)
if ary.flags.forc:
# pyopencl supports host transfers only for contiguous arrays.
return ary.get(queue=self.queue)

result = self.call_loopy(
Comment thread
inducer marked this conversation as resolved.
self._get_to_numpy_noncontiguous_copy_kernel(ary.dtype, ary.ndim),
input=ary,
**{
f"s{i}": stride // ary.dtype.itemsize
for i, stride in enumerate(ary.strides)
},
)["output"]
return result.get(queue=self.queue)

return with_array_context(
self._rec_map_container(_to_numpy, array),
Expand Down
32 changes: 31 additions & 1 deletion arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,39 @@ def absolute(self, a):
# {{{ sorting, searching, and counting

def where(self, criterion, then, else_):
def where_inner(inner_crit, inner_then, inner_else):

def where_inner(
inner_crit: ArrayOrScalar,
inner_then: ArrayOrScalar,
inner_else: ArrayOrScalar,
) -> ArrayOrScalar:
if isinstance(inner_crit, bool | np.bool_):
return inner_then if inner_crit else inner_else

# pyopencl's if_positive does not support then, else branches with
# unequal dtypes -> cast them to a common dtype.
Comment thread
kaushikcfd marked this conversation as resolved.
inner_then_dtype = (
inner_then.dtype
if isinstance(inner_then, cl_array.Array)
else np.dtype(type(inner_then))
)
inner_else_dtype = (
inner_else.dtype
if isinstance(inner_else, cl_array.Array)
else np.dtype(type(inner_else))
)
dtype = np.promote_types(inner_then_dtype, inner_else_dtype)
inner_then = (
inner_then.astype(dtype)
if isinstance(inner_then, cl_array.Array)
else dtype.type(inner_then)
)
inner_else = (
inner_else.astype(dtype)
if isinstance(inner_else, cl_array.Array)
else dtype.type(inner_else)
)

return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
queue=self._array_context.queue)

Expand Down
19 changes: 19 additions & 0 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,25 @@ def test_linspace(actx_factory: ArrayContextFactory, args, kwargs):
assert np.allclose(actx_linspace, np_linspace)


# {{{ test_to_numpy_transpose

def test_to_numpy_transpose(actx_factory: ArrayContextFactory):
# fails prior to <https://github.com/inducer/arraycontext/pull/357> for
# pyopencl actx -- cl_array.Array.transpose generates non-contiguous
# arrays requiring non-trivial logic for to host copies.
actx = actx_factory()
rng = np.random.default_rng()
np_ary = rng.random((256, 256, 256))
ary = actx.from_numpy(np_ary)
axis_perm = (0, 2, 1)

np.testing.assert_allclose(
actx.to_numpy(actx.np.transpose(ary, axis_perm)),
np.transpose(np_ary, axis_perm))

# }}}


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down
Loading