Skip to content

Commit a764341

Browse files
Joan Fontanals and JohannesMessner authored
feat: add method to create BaseDoc from schema (#1667)
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai> Signed-off-by: Joan Fontanals <jfontanalsmartinez@gmail.com> Co-authored-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com>
1 parent f507a5f commit a764341

File tree

2 files changed

+476
-0
lines changed

2 files changed

+476
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
from docarray import DocList, BaseDoc
2+
from docarray.typing import AnyTensor
3+
from pydantic import create_model
4+
from typing import Dict, List, Any, Union, Optional, Type
5+
6+
7+
def create_pure_python_type_model(model: Any) -> Type[BaseDoc]:
    """
    Take a Pydantic model and cast DocList fields into List fields.

    This may be necessary due to limitations in Pydantic:

    https://github.com/docarray/docarray/issues/1521
    https://github.com/pydantic/pydantic/issues/1457

    Only the annotations declared directly on `model` are rebuilt; everything else
    is inherited because the new class is created with `model` as its base.
    Each rebuilt field reuses the `FieldInfo` of the original field, so defaults,
    descriptions and "required" status are carried over unchanged.

    ---

    ```python
    from docarray import BaseDoc


    class MyDoc(BaseDoc):
        tensor: Optional[AnyTensor]
        url: ImageUrl
        title: str
        texts: DocList[TextDoc]


    MyDocCorrected = create_pure_python_type_model(MyDoc)
    ```

    ---
    :param model: The input model
    :return: A new subclass of `model`, where every DocList type in the schema is replaced by List.
    """
    fields: Dict[str, Any] = {}
    for field_name, field in model.__annotations__.items():
        model_field = model.__fields__.get(field_name)
        if model_field is None:
            # Annotated name that is not a pydantic field (e.g. a ClassVar):
            # leave it alone, the new class inherits it from `model`.
            continue

        field_type: Any = field
        try:
            if issubclass(field, DocList):
                doc_type: Any = field.doc_type
                field_type = List[doc_type]
        except TypeError:
            # `issubclass` rejects non-class annotations such as Optional[...]
            # or List[...]; those are kept exactly as declared.
            field_type = field

        # The second tuple element is what pydantic uses as the field's
        # default/FieldInfo. Reusing the original FieldInfo keeps required
        # fields required instead of giving every field a `{}` default.
        fields[field_name] = (field_type, model_field.field_info)

    return create_model(
        model.__name__, __base__=model, __validators__=model.__validators__, **fields
    )
49+
50+
51+
def _get_field_type_from_schema(
    field_schema: Dict[str, Any],
    field_name: str,
    root_schema: Dict[str, Any],
    cached_models: Dict[str, Any],
    is_tensor: bool = False,
    num_recursions: int = 0,
) -> type:
    """
    Private method used to extract the corresponding field type from the schema.

    Arrays are handled by recursing into their `items` schema with
    `num_recursions + 1`; the innermost type is then wrapped in that many
    `List[...]` layers (or turned into `AnyTensor` / `DocList`, see below).

    :param field_schema: The schema from which to extract the type
    :param field_name: The name of the field to be created
    :param root_schema: The schema of the root object, important to get references
    :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
    :param is_tensor: Boolean used to tell between tensor and list
    :param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc ..)
    :return: A type created from the schema
    :raises ValueError: If the schema declares a `type` that is not handled here.
    :raises KeyError: If a `$ref` has to be resolved and `root_schema` has no matching `definitions` entry.
    """
    field_type = field_schema.get('type', None)
    tensor_shape = field_schema.get('tensor/array shape', None)
    scalar_types: Dict[str, Any] = {'string': str, 'integer': int, 'boolean': bool}
    ret: Any

    if 'anyOf' in field_schema:
        any_of_types = []
        for any_of_schema in field_schema['anyOf']:
            if '$ref' in any_of_schema:
                any_of_types.append(
                    _doc_type_from_ref(
                        any_of_schema['$ref'], root_schema, cached_models
                    )
                )
            else:
                any_of_types.append(
                    _get_field_type_from_schema(
                        any_of_schema,
                        field_name,
                        root_schema=root_schema,
                        cached_models=cached_models,
                        is_tensor=tensor_shape is not None,
                        num_recursions=0,
                    )
                )  # No Union of Lists
        ret = _wrap_in_lists(Union[tuple(any_of_types)], num_recursions)
    elif isinstance(field_type, str) and field_type in scalar_types:
        ret = _wrap_in_lists(scalar_types[field_type], num_recursions)
    elif field_type == 'number':
        if is_tensor and num_recursions <= 1:
            # This is a hack because AnyTensor is more generic than a simple List and it comes as simple List
            ret = AnyTensor
        else:
            # A plain number (depth 0) must stay `float`; only array items get
            # wrapped, once per level of array nesting.
            ret = _wrap_in_lists(float, num_recursions)
    elif field_type == 'object' or field_type is None:
        doc_type: Any
        if 'additionalProperties' in field_schema:  # handle Dictionaries
            additional_props = field_schema['additionalProperties']
            # JSON schema also allows a boolean here, which has no `.get`.
            if (
                isinstance(additional_props, dict)
                and additional_props.get('type') == 'object'
            ):
                doc_type = create_base_doc_from_schema(
                    _with_definitions(additional_props, root_schema),
                    field_name,
                    cached_models=cached_models,
                )
                mapping_type: Any = Dict[str, doc_type]
            else:
                mapping_type = Dict[str, Any]
            # Keep the surrounding array levels, e.g. List[Dict[str, Any]].
            ret = _wrap_in_lists(mapping_type, num_recursions)
        else:
            obj_ref = field_schema.get('$ref') or field_schema.get('allOf', [{}])[
                0
            ].get('$ref', None)
            if num_recursions == 0:  # single object reference
                if obj_ref:
                    ret = _doc_type_from_ref(obj_ref, root_schema, cached_models)
                else:
                    ret = Any
            else:  # object reference in definitions
                if obj_ref:
                    doc_type = _doc_type_from_ref(
                        obj_ref, root_schema, cached_models
                    )
                else:
                    doc_type = create_base_doc_from_schema(
                        _with_definitions(field_schema, root_schema),
                        field_name,
                        cached_models=cached_models,
                    )
                ret = DocList[doc_type]
    elif field_type == 'array':
        ret = _get_field_type_from_schema(
            field_schema=field_schema.get('items', {}),
            field_name=field_name,
            root_schema=root_schema,
            cached_models=cached_models,
            is_tensor=tensor_shape is not None,
            num_recursions=num_recursions + 1,
        )
    else:
        if num_recursions > 0:
            raise ValueError(
                f"Unknown array item type: {field_type} for field_name {field_name}"
            )
        else:
            raise ValueError(
                f"Unknown field type: {field_type} for field_name {field_name}"
            )
    return ret


def _wrap_in_lists(inner_type: Any, depth: int) -> Any:
    """Wrap `inner_type` in `depth` nested `List[...]` layers (`depth == 0` returns it unchanged)."""
    wrapped = inner_type
    for _ in range(depth):
        wrapped = List[wrapped]
    return wrapped


def _with_definitions(
    schema: Dict[str, Any], root_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """Return `schema` carrying the root's `definitions`, so `$ref`s nested inside it stay resolvable."""
    if 'definitions' in schema or 'definitions' not in root_schema:
        return schema
    # Shallow copy: the caller's schema dictionaries are never mutated.
    return {**schema, 'definitions': root_schema['definitions']}


def _doc_type_from_ref(
    obj_ref: str, root_schema: Dict[str, Any], cached_models: Dict[str, Any]
) -> Any:
    """Build (or fetch from `cached_models`) the document class a `$ref` such as `#/definitions/Name` points to."""
    ref_name = obj_ref.split('/')[-1]
    return create_base_doc_from_schema(
        _with_definitions(root_schema['definitions'][ref_name], root_schema),
        ref_name,
        cached_models=cached_models,
    )
180+
181+
182+
def create_base_doc_from_schema(
    schema: Dict[str, Any], base_doc_name: str, cached_models: Optional[Dict] = None
) -> Type:
    """
    Dynamically create a `BaseDoc` subclass from a `schema` of another `BaseDoc`.

    This method is intended to dynamically create a `BaseDoc` compatible with the schema
    of another BaseDoc. This is useful when that other `BaseDoc` is not available in the current scope. For instance, you may have stored the schema
    as a JSON, or sent it to another service, etc.

    Due to this Pydantic limitation (https://github.com/docarray/docarray/issues/1521, https://github.com/pydantic/pydantic/issues/1457), we need to make sure that the
    input schema uses `List` and not `DocList`. Therefore this is recommended to be used in combination with `create_pure_python_type_model`
    to make sure that `DocLists` in schema are converted to `List`.

    Every property of the schema becomes a field whose default is the schema's
    `default` entry (`None` when the schema declares none) and whose description
    is the schema's `description` entry. The schema's `required` list is not consulted.

    ---

    ```python
    from docarray import BaseDoc


    class MyDoc(BaseDoc):
        tensor: Optional[AnyTensor]
        url: ImageUrl
        title: str
        texts: DocList[TextDoc]


    MyDocCorrected = create_pure_python_type_model(MyDoc)
    new_my_doc_cls = create_base_doc_from_schema(MyDocCorrected.schema(), 'MyDoc')
    ```

    ---
    :param schema: The schema of the original `BaseDoc` where DocLists are passed as regular Lists of Documents.
    :param base_doc_name: The name of the new pydantic model created.
    :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
    :return: A BaseDoc class dynamically created following the `schema`.
    """
    # Imported locally because `Field` is not among the module-level imports.
    from pydantic import Field

    cached_models = cached_models if cached_models is not None else {}
    if base_doc_name in cached_models:
        # Already built earlier in this (possibly recursive) call chain.
        return cached_models[base_doc_name]

    fields: Dict[str, Any] = {}
    for field_name, field_schema in schema.get('properties', {}).items():
        field_type = _get_field_type_from_schema(
            field_schema=field_schema,
            field_name=field_name,
            root_schema=schema,
            cached_models=cached_models,
            is_tensor=False,
            num_recursions=0,
        )
        # The second tuple element is the field's default/FieldInfo. The
        # description must travel as metadata, not as the default value.
        fields[field_name] = (
            field_type,
            Field(
                default=field_schema.get('default'),
                description=field_schema.get('description'),
            ),
        )

    model = create_model(base_doc_name, __base__=BaseDoc, **fields)
    cached_models[base_doc_name] = model
    return model

0 commit comments

Comments
 (0)