Skip to content

Commit adb0d01

Browse files
authored
fix: fix dynamic class creation with doubly nested schemas (#1747)
Signed-off-by: Alaeddine Abdessalem <alaeddine-13@live.fr>
1 parent 691d939 commit adb0d01

File tree

2 files changed: +42 −14 lines changed

docarray/utils/create_dynamic_doc_class.py

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
1-
from docarray import DocList, BaseDoc
2-
from docarray.typing import AnyTensor
1+
from typing import Any, Dict, List, Optional, Type, Union
2+
33
from pydantic import create_model
44
from pydantic.fields import FieldInfo
5-
from typing import Dict, List, Any, Union, Optional, Type
6-
from docarray.utils._internal._typing import safe_issubclass
75

6+
from docarray import BaseDoc, DocList
7+
from docarray.typing import AnyTensor
8+
from docarray.utils._internal._typing import safe_issubclass
89

910
RESERVED_KEYS = [
1011
'type',
@@ -71,6 +72,7 @@ def _get_field_type_from_schema(
7172
cached_models: Dict[str, Any],
7273
is_tensor: bool = False,
7374
num_recursions: int = 0,
75+
definitions: Optional[Dict] = None,
7476
) -> type:
7577
"""
7678
Private method used to extract the corresponding field type from the schema.
@@ -80,8 +82,11 @@ def _get_field_type_from_schema(
8082
:param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
8183
:param is_tensor: Boolean used to tell between tensor and list
8284
:param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc ..)
85+
:param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
8386
:return: A type created from the schema
8487
"""
88+
if not definitions:
89+
definitions = {}
8590
field_type = field_schema.get('type', None)
8691
tensor_shape = field_schema.get('tensor/array shape', None)
8792
ret: Any
@@ -96,6 +101,7 @@ def _get_field_type_from_schema(
96101
root_schema['definitions'][ref_name],
97102
ref_name,
98103
cached_models=cached_models,
104+
definitions=definitions,
99105
)
100106
)
101107
else:
@@ -107,6 +113,7 @@ def _get_field_type_from_schema(
107113
cached_models=cached_models,
108114
is_tensor=tensor_shape is not None,
109115
num_recursions=0,
116+
definitions=definitions,
110117
)
111118
) # No Union of Lists
112119
ret = Union[tuple(any_of_types)]
@@ -154,19 +161,21 @@ def _get_field_type_from_schema(
154161
if obj_ref:
155162
ref_name = obj_ref.split('/')[-1]
156163
ret = create_base_doc_from_schema(
157-
root_schema['definitions'][ref_name],
164+
definitions[ref_name],
158165
ref_name,
159166
cached_models=cached_models,
167+
definitions=definitions,
160168
)
161169
else:
162170
ret = Any
163171
else: # object reference in definitions
164172
if obj_ref:
165173
ref_name = obj_ref.split('/')[-1]
166174
doc_type = create_base_doc_from_schema(
167-
root_schema['definitions'][ref_name],
175+
definitions[ref_name],
168176
ref_name,
169177
cached_models=cached_models,
178+
definitions=definitions,
170179
)
171180
ret = DocList[doc_type]
172181
else:
@@ -182,6 +191,7 @@ def _get_field_type_from_schema(
182191
cached_models=cached_models,
183192
is_tensor=tensor_shape is not None,
184193
num_recursions=num_recursions + 1,
194+
definitions=definitions,
185195
)
186196
else:
187197
if num_recursions > 0:
@@ -196,7 +206,10 @@ def _get_field_type_from_schema(
196206

197207

198208
def create_base_doc_from_schema(
199-
schema: Dict[str, Any], base_doc_name: str, cached_models: Optional[Dict] = None
209+
schema: Dict[str, Any],
210+
base_doc_name: str,
211+
cached_models: Optional[Dict] = None,
212+
definitions: Optional[Dict] = None,
200213
) -> Type:
201214
"""
202215
Dynamically create a `BaseDoc` subclass from a `schema` of another `BaseDoc`.
@@ -230,8 +243,12 @@ class MyDoc(BaseDoc):
230243
:param schema: The schema of the original `BaseDoc` where DocLists are passed as regular Lists of Documents.
231244
:param base_doc_name: The name of the new pydantic model created.
232245
:param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
246+
:param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
233247
:return: A BaseDoc class dynamically created following the `schema`.
234248
"""
249+
if not definitions:
250+
definitions = schema.get('definitions', {})
251+
235252
cached_models = cached_models if cached_models is not None else {}
236253
fields: Dict[str, Any] = {}
237254
if base_doc_name in cached_models:
@@ -245,6 +262,7 @@ class MyDoc(BaseDoc):
245262
cached_models=cached_models,
246263
is_tensor=False,
247264
num_recursions=0,
265+
definitions=definitions,
248266
)
249267
fields[field_name] = (
250268
field_type,

tests/units/util/test_create_dynamic_code_class.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,26 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
3+
import numpy as np
14
import pytest
2-
from typing import List, Dict, Union, Any
5+
from pydantic import Field
6+
7+
from docarray import BaseDoc, DocList
8+
from docarray.documents import TextDoc
9+
from docarray.typing import AnyTensor, ImageUrl
310
from docarray.utils.create_dynamic_doc_class import (
411
create_base_doc_from_schema,
512
create_pure_python_type_model,
613
)
7-
import numpy as np
8-
from typing import Optional
9-
from docarray import BaseDoc, DocList
10-
from docarray.typing import AnyTensor, ImageUrl
11-
from docarray.documents import TextDoc
12-
from pydantic import Field
1314

1415

1516
@pytest.mark.parametrize('transformation', ['proto', 'json'])
1617
def test_create_pydantic_model_from_schema(transformation):
18+
class Nested2Doc(BaseDoc):
19+
value: str
20+
21+
class Nested1Doc(BaseDoc):
22+
nested: Nested2Doc
23+
1724
class CustomDoc(BaseDoc):
1825
tensor: Optional[AnyTensor]
1926
url: ImageUrl
@@ -26,6 +33,7 @@ class CustomDoc(BaseDoc):
2633
u: Union[str, int]
2734
lu: List[Union[str, int]] = [0, 1, 2]
2835
tags: Optional[Dict[str, Any]] = None
36+
nested: Nested1Doc
2937

3038
CustomDocCopy = create_pure_python_type_model(CustomDoc)
3139
new_custom_doc_model = create_base_doc_from_schema(
@@ -43,6 +51,7 @@ class CustomDoc(BaseDoc):
4351
single_text=TextDoc(text='single hey ha', embedding=np.zeros(2)),
4452
u='a',
4553
lu=[3, 4],
54+
nested=Nested1Doc(nested=Nested2Doc(value='hello world')),
4655
)
4756
]
4857
)
@@ -77,6 +86,7 @@ class CustomDoc(BaseDoc):
7786
assert custom_partial_da[0].u == 'a'
7887
assert custom_partial_da[0].single_text.text == 'single hey ha'
7988
assert custom_partial_da[0].single_text.embedding.shape == (2,)
89+
assert original_back[0].nested.nested.value == 'hello world'
8090

8191
assert len(original_back) == 1
8292
assert original_back[0].url == 'photo.jpg'

0 commit comments

Comments
 (0)