2323 Type ,
2424 TypeVar ,
2525 Union ,
26+ cast ,
2627)
2728
2829import orjson
4041if TYPE_CHECKING :
4142 import pandas as pd
4243
44+ from docarray .array .doc_vec .doc_vec import DocVec
45+ from docarray .array .doc_vec .io import IOMixinDocVec
4346 from docarray .proto import DocListProto
47+ from docarray .typing .tensor .abstract_tensor import AbstractTensor
4448
45- T = TypeVar ('T' , bound = 'IOMixinArray ' )
49+ T = TypeVar ('T' , bound = 'IOMixinDocList ' )
4650T_doc = TypeVar ('T_doc' , bound = BaseDoc )
4751
4852ARRAY_PROTOCOLS = {'protobuf-array' , 'pickle-array' , 'json-array' }
@@ -96,7 +100,7 @@ def __getitem__(self, item: slice):
96100 return self .content [item ]
97101
98102
99- class IOMixinArray (Iterable [T_doc ]):
103+ class IOMixinDocList (Iterable [T_doc ]):
100104 doc_type : Type [T_doc ]
101105
102106 @abstractmethod
@@ -515,8 +519,6 @@ class Person(BaseDoc):
515519 doc_dict = _access_path_dict_to_nested_dict (access_path2val )
516520 docs .append (doc_type .parse_obj (doc_dict ))
517521
518- if not isinstance (docs , cls ):
519- return cls (docs )
520522 return docs
521523
522524 def to_dataframe (self ) -> 'pd.DataFrame' :
@@ -577,11 +579,13 @@ def _load_binary_all(
577579 protocol : Optional [str ],
578580 compress : Optional [str ],
579581 show_progress : bool ,
582+ tensor_type : Optional [Type ['AbstractTensor' ]] = None ,
580583 ):
581584 """Read a `DocList` object from a binary file
582585 :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf'
583586 :param compress: compress algorithm to use between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
584587 :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
588+ :param tensor_type: only relevant for DocVec; tensor_type of the DocVec
585589 :return: a `DocList`
586590 """
587591 with file_ctx as fp :
@@ -603,12 +607,20 @@ def _load_binary_all(
603607 proto = cls ._get_proto_class ()()
604608 proto .ParseFromString (d )
605609
606- return cls .from_protobuf (proto )
610+ if tensor_type is not None :
611+ cls_ = cast ('IOMixinDocVec' , cls )
612+ return cls_ .from_protobuf (proto , tensor_type = tensor_type )
613+ else :
614+ return cls .from_protobuf (proto )
607615 elif protocol is not None and protocol == 'pickle-array' :
608616 return pickle .loads (d )
609617
610618 elif protocol is not None and protocol == 'json-array' :
611- return cls .from_json (d )
619+ if tensor_type is not None :
620+ cls_ = cast ('IOMixinDocVec' , cls )
621+ return cls_ .from_json (d , tensor_type = tensor_type )
622+ else :
623+ return cls .from_json (d )
612624
613625 # Binary format for streaming case
614626 else :
@@ -658,6 +670,10 @@ def _load_binary_all(
658670 pbar .update (
659671 t , advance = 1 , total_size = str (filesize .decimal (_total_size ))
660672 )
673+ if tensor_type is not None :
674+ cls__ = cast (Type ['DocVec' ], cls )
675+ # mypy doesn't realize that cls_ is callable
676+ return cls__ (docs , tensor_type = tensor_type ) # type: ignore
661677 return cls (docs )
662678
663679 @classmethod
@@ -724,6 +740,27 @@ def _load_binary_stream(
724740 t , advance = 1 , total_size = str (filesize .decimal (_total_size ))
725741 )
726742
743+ @staticmethod
744+ def _get_file_context (
745+ file : Union [str , bytes , pathlib .Path , io .BufferedReader , _LazyRequestReader ],
746+ protocol : str ,
747+ compress : Optional [str ] = None ,
748+ ) -> Tuple [Union [nullcontext , io .BufferedReader ], Optional [str ], Optional [str ]]:
749+ load_protocol : Optional [str ] = protocol
750+ load_compress : Optional [str ] = compress
751+ file_ctx : Union [nullcontext , io .BufferedReader ]
752+ if isinstance (file , (io .BufferedReader , _LazyRequestReader , bytes )):
753+ file_ctx = nullcontext (file )
754+ # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
755+ elif isinstance (file , (str , pathlib .Path )) and os .path .exists (file ):
756+ load_protocol , load_compress = _protocol_and_compress_from_file_path (
757+ file , protocol , compress
758+ )
759+ file_ctx = open (file , 'rb' )
760+ else :
761+ raise FileNotFoundError (f'cannot find file { file } ' )
762+ return file_ctx , load_protocol , load_compress
763+
727764 @classmethod
728765 def load_binary (
729766 cls : Type [T ],
@@ -753,19 +790,9 @@ def load_binary(
753790 :return: a `DocList` object
754791
755792 """
756- load_protocol : Optional [str ] = protocol
757- load_compress : Optional [str ] = compress
758- file_ctx : Union [nullcontext , io .BufferedReader ]
759- if isinstance (file , (io .BufferedReader , _LazyRequestReader , bytes )):
760- file_ctx = nullcontext (file )
761- # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
762- elif isinstance (file , (str , pathlib .Path )) and os .path .exists (file ):
763- load_protocol , load_compress = _protocol_and_compress_from_file_path (
764- file , protocol , compress
765- )
766- file_ctx = open (file , 'rb' )
767- else :
768- raise FileNotFoundError (f'cannot find file { file } ' )
793+ file_ctx , load_protocol , load_compress = cls ._get_file_context (
794+ file , protocol , compress
795+ )
769796 if streaming :
770797 if load_protocol not in SINGLE_PROTOCOLS :
771798 raise ValueError (
0 commit comments