1- from docarray import DocList , BaseDoc
2- from docarray . typing import AnyTensor
1+ from typing import Any , Dict , List , Optional , Type , Union
2+
33from pydantic import create_model
44from pydantic .fields import FieldInfo
5- from typing import Dict , List , Any , Union , Optional , Type
6- from docarray .utils ._internal ._typing import safe_issubclass
75
6+ from docarray import BaseDoc , DocList
7+ from docarray .typing import AnyTensor
8+ from docarray .utils ._internal ._typing import safe_issubclass
89
910RESERVED_KEYS = [
1011 'type' ,
@@ -71,6 +72,7 @@ def _get_field_type_from_schema(
7172 cached_models : Dict [str , Any ],
7273 is_tensor : bool = False ,
7374 num_recursions : int = 0 ,
75+ definitions : Optional [Dict ] = None ,
7476) -> type :
7577 """
7678 Private method used to extract the corresponding field type from the schema.
@@ -80,8 +82,11 @@ def _get_field_type_from_schema(
8082 :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
8183 :param is_tensor: Boolean used to tell between tensor and list
8284 :param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc ..)
85+ :param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
8386 :return: A type created from the schema
8487 """
88+ if not definitions :
89+ definitions = {}
8590 field_type = field_schema .get ('type' , None )
8691 tensor_shape = field_schema .get ('tensor/array shape' , None )
8792 ret : Any
@@ -96,6 +101,7 @@ def _get_field_type_from_schema(
96101 root_schema ['definitions' ][ref_name ],
97102 ref_name ,
98103 cached_models = cached_models ,
104+ definitions = definitions ,
99105 )
100106 )
101107 else :
@@ -107,6 +113,7 @@ def _get_field_type_from_schema(
107113 cached_models = cached_models ,
108114 is_tensor = tensor_shape is not None ,
109115 num_recursions = 0 ,
116+ definitions = definitions ,
110117 )
111118 ) # No Union of Lists
112119 ret = Union [tuple (any_of_types )]
@@ -154,19 +161,21 @@ def _get_field_type_from_schema(
154161 if obj_ref :
155162 ref_name = obj_ref .split ('/' )[- 1 ]
156163 ret = create_base_doc_from_schema (
157- root_schema [ ' definitions' ] [ref_name ],
164+ definitions [ref_name ],
158165 ref_name ,
159166 cached_models = cached_models ,
167+ definitions = definitions ,
160168 )
161169 else :
162170 ret = Any
163171 else : # object reference in definitions
164172 if obj_ref :
165173 ref_name = obj_ref .split ('/' )[- 1 ]
166174 doc_type = create_base_doc_from_schema (
167- root_schema [ ' definitions' ] [ref_name ],
175+ definitions [ref_name ],
168176 ref_name ,
169177 cached_models = cached_models ,
178+ definitions = definitions ,
170179 )
171180 ret = DocList [doc_type ]
172181 else :
@@ -182,6 +191,7 @@ def _get_field_type_from_schema(
182191 cached_models = cached_models ,
183192 is_tensor = tensor_shape is not None ,
184193 num_recursions = num_recursions + 1 ,
194+ definitions = definitions ,
185195 )
186196 else :
187197 if num_recursions > 0 :
@@ -196,7 +206,10 @@ def _get_field_type_from_schema(
196206
197207
198208def create_base_doc_from_schema (
199- schema : Dict [str , Any ], base_doc_name : str , cached_models : Optional [Dict ] = None
209+ schema : Dict [str , Any ],
210+ base_doc_name : str ,
211+ cached_models : Optional [Dict ] = None ,
212+ definitions : Optional [Dict ] = None ,
200213) -> Type :
201214 """
202215 Dynamically create a `BaseDoc` subclass from a `schema` of another `BaseDoc`.
@@ -230,8 +243,12 @@ class MyDoc(BaseDoc):
230243 :param schema: The schema of the original `BaseDoc` where DocLists are passed as regular Lists of Documents.
231244 :param base_doc_name: The name of the new pydantic model created.
232245 :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
246+ :param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
233247 :return: A BaseDoc class dynamically created following the `schema`.
234248 """
249+ if not definitions :
250+ definitions = schema .get ('definitions' , {})
251+
235252 cached_models = cached_models if cached_models is not None else {}
236253 fields : Dict [str , Any ] = {}
237254 if base_doc_name in cached_models :
@@ -245,6 +262,7 @@ class MyDoc(BaseDoc):
245262 cached_models = cached_models ,
246263 is_tensor = False ,
247264 num_recursions = 0 ,
265+ definitions = definitions ,
248266 )
249267 fields [field_name ] = (
250268 field_type ,
0 commit comments