diff --git a/src/libpython_clj2/python/np_array.clj b/src/libpython_clj2/python/np_array.clj index 9152288..869be5c 100644 --- a/src/libpython_clj2/python/np_array.clj +++ b/src/libpython_clj2/python/np_array.clj @@ -73,12 +73,25 @@ :ctypes ctypes}))) +(defn- zero-copyable-dtype? + "Object/str/datetime arrays hold python-object pointers, not a numeric native + buffer, so only the numeric dtypes in py-dtype->dtype-map can become a tensor." + [pyobj] + (py-ffi/with-gil + (let [dtype-name (-> (py-proto/get-attr pyobj "dtype") + (py-proto/get-attr "name"))] + (contains? py-dtype->dtype-map dtype-name)))) + + (defmethod py-proto/pyobject->jvm :ndarray [pyobj opts] - (pygc/with-stack-context - (-> (numpy->desc pyobj) - (dtt/nd-buffer-descriptor->tensor) - (dtt/clone)))) + (if (zero-copyable-dtype? pyobj) + (pygc/with-stack-context + (-> (numpy->desc pyobj) + (dtt/nd-buffer-descriptor->tensor) + (dtt/clone))) + ;; non-numeric dtype (object/str/datetime): can't zero-copy, fall back to copy + ((get-method py-proto/pyobject->jvm :default) pyobj opts))) (defmethod py-proto/pyobject-as-jvm :ndarray diff --git a/test/libpython_clj2/numpy_test.clj b/test/libpython_clj2/numpy_test.clj index 91d409c..7c32a26 100644 --- a/test/libpython_clj2/numpy_test.clj +++ b/test/libpython_clj2/numpy_test.clj @@ -1,6 +1,8 @@ (ns libpython-clj2.numpy-test (:require [clojure.test :refer [deftest is]] [libpython-clj2.python :as py] + ;; loading these zero-copy bindings is what made object-dtype ->jvm fail + [libpython-clj2.python.np-array] [tech.v3.datatype :as dtype] [tech.v3.datatype.functional :as dfn] [tech.v3.tensor :as dtt])) @@ -18,3 +20,10 @@ (is (dfn/equals (dtt/ensure-tensor np-ary) tens)) (is (dfn/equals [1 2 3 4] (dtype/make-container :java-array :int64 np-ary))))) + + +(deftest object-dtype-ndarray->jvm + (let [empty-obj (py/call-attr-kw np-mod "array" [[]] {:dtype "object"}) + mixed-obj (py/call-attr-kw np-mod "array" [["a" 1]] {:dtype "object"})] + (is (= [] (py/->jvm empty-obj))) + (is (= ["a" 1] (py/->jvm mixed-obj)))))