Commit d98acb7

fix: fix DocList schema when using Pydantic V2 (#1876)
1 parent 83ebef6 commit d98acb7

33 files changed: +624 −126 lines changed


.github/workflows/cd.yml

Lines changed: 7 additions & 11 deletions

@@ -21,7 +21,7 @@ jobs:
       - name: Pre-release (.devN)
         run: |
           git fetch --depth=1 origin +refs/tags/*:refs/tags/*
-          pip install poetry
+          pip install poetry==1.7.1
           ./scripts/release.sh
         env:
           PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }}
@@ -35,20 +35,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
         with:
-          fetch-depth: 0
-
-      - name: Get changed files
-        id: changed-files-specific
-        uses: tj-actions/changed-files@v41
-        with:
-          files: |
-            README.md
+          fetch-depth: 2

       - name: Check if README is modified
         id: step_output
-        if: steps.changed-files-specific.outputs.any_changed == 'true'
         run: |
-          echo "readme_changed=true" >> $GITHUB_OUTPUT
+          if git diff --name-only HEAD^ HEAD | grep -q "README.md"; then
+            echo "readme_changed=true" >> $GITHUB_OUTPUT
+          else
+            echo "readme_changed=false" >> $GITHUB_OUTPUT
+          fi

   publish-docarray-org:
     needs: check-readme-modification

.github/workflows/ci.yml

Lines changed: 10 additions & 10 deletions

@@ -25,7 +25,7 @@ jobs:
       - name: Lint with ruff
         run: |
           python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
           poetry install

           # stop the build if there are Python syntax errors or undefined names
@@ -44,7 +44,7 @@ jobs:
       - name: check black
         run: |
           python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
           poetry install --only dev
           poetry run black --check .
@@ -62,7 +62,7 @@ jobs:
       - name: Prepare environment
         run: |
           python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
           poetry install --without dev
           poetry run pip install tensorflow==2.12.0
           poetry run pip install jax
@@ -106,7 +106,7 @@ jobs:
      - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
          poetry install --all-extras
          poetry run pip install elasticsearch==8.6.2
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
@@ -156,7 +156,7 @@ jobs:
      - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0 # we check that we support 3.19
@@ -204,7 +204,7 @@ jobs:
      - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0
@@ -253,7 +253,7 @@ jobs:
      - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0
@@ -302,7 +302,7 @@ jobs:
      - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0
@@ -351,7 +351,7 @@ jobs:
      - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip uninstall -y torch
@@ -398,7 +398,7 @@ jobs:
      - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
          poetry install --all-extras
          poetry run pip uninstall -y torch
          poetry run pip install torch

.github/workflows/ci_only_pr.yml

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ jobs:
         run: |
           npm i -g netlify-cli
           python -m pip install --upgrade pip
-          python -m pip install poetry
+          python -m pip install poetry==1.7.1
           python -m poetry config virtualenvs.create false && python -m poetry install --no-interaction --no-ansi --all-extras

           cd docs

docarray/__init__.py

Lines changed: 54 additions & 0 deletions

@@ -20,6 +20,60 @@
 from docarray.array import DocList, DocVec
 from docarray.base_doc.doc import BaseDoc
 from docarray.utils._internal.misc import _get_path_from_docarray_root_level
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+
+def unpickle_doclist(doc_type, b):
+    return DocList[doc_type].from_bytes(b, protocol="protobuf")
+
+
+def unpickle_docvec(doc_type, tensor_type, b):
+    return DocVec[doc_type].from_bytes(b, protocol="protobuf", tensor_type=tensor_type)
+
+
+if is_pydantic_v2:
+    # Register the pickle functions
+    def register_serializers():
+        import copyreg
+        from functools import partial
+
+        unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf")
+
+        def pickle_doc(doc):
+            b = doc.to_bytes(protocol='protobuf')
+            return unpickle_doc_fn, (doc.__class__, b)
+
+        # Register BaseDoc serialization
+        copyreg.pickle(BaseDoc, pickle_doc)
+
+        # For DocList, we need to hook into __reduce__ since it's a generic
+        def pickle_doclist(doc_list):
+            b = doc_list.to_bytes(protocol='protobuf')
+            doc_type = doc_list.doc_type
+            return unpickle_doclist, (doc_type, b)
+
+        # Replace DocList.__reduce__ with a method that returns the correct format
+        def doclist_reduce(self):
+            return pickle_doclist(self)
+
+        DocList.__reduce__ = doclist_reduce
+
+        # For DocVec, we need to hook into __reduce__ since it's a generic
+        def pickle_docvec(doc_vec):
+            b = doc_vec.to_bytes(protocol='protobuf')
+            doc_type = doc_vec.doc_type
+            tensor_type = doc_vec.tensor_type
+            return unpickle_docvec, (doc_type, tensor_type, b)
+
+        # Replace DocVec.__reduce__ with a method that returns the correct format
+        def docvec_reduce(self):
+            return pickle_docvec(self)
+
+        DocVec.__reduce__ = docvec_reduce
+
+    register_serializers()

 __all__ = ['BaseDoc', 'DocList', 'DocVec']
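A minimal sketch of what this registration enables, assuming Pydantic V2 is active (`MyDoc` is an illustrative document type, not part of the commit):

import pickle

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str


docs = DocList[MyDoc]([MyDoc(text='hello'), MyDoc(text='world')])

# The patched __reduce__ serializes the container to protobuf bytes and
# rebuilds it through unpickle_doclist, so the parametrized generic
# survives a pickle round-trip (e.g. across multiprocessing workers).
restored = pickle.loads(pickle.dumps(docs))
assert restored[0].text == 'hello'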

docarray/array/any_array.py

Lines changed: 31 additions & 9 deletions

@@ -25,6 +25,7 @@
 from docarray.exceptions.exceptions import UnusableObjectError
 from docarray.typing.abstract_type import AbstractType
 from docarray.utils._internal._typing import change_cls_name, safe_issubclass
+from docarray.utils._internal.pydantic import is_pydantic_v2

 if TYPE_CHECKING:
     from docarray.proto import DocListProto, NodeProto
@@ -73,8 +74,19 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
         # Promote to global scope so multiprocessing can pickle it
         global _DocArrayTyped

-        class _DocArrayTyped(cls):  # type: ignore
-            doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)
+        if not is_pydantic_v2:
+
+            class _DocArrayTyped(cls):  # type: ignore
+                doc_type: Type[BaseDocWithoutId] = cast(
+                    Type[BaseDocWithoutId], item
+                )
+
+        else:
+
+            class _DocArrayTyped(cls, Generic[T_doc]):  # type: ignore
+                doc_type: Type[BaseDocWithoutId] = cast(
+                    Type[BaseDocWithoutId], item
+                )

         for field in _DocArrayTyped.doc_type._docarray_fields().keys():

@@ -99,14 +111,24 @@ def _setter(self, value):
             setattr(_DocArrayTyped, field, _property_generator(field))
             # this generates property on the fly based on the schema of the item

-        # The global scope and qualname need to give this class a unique name.
-        # Otherwise, creating another _DocArrayTyped will overwrite this one.
-        change_cls_name(
-            _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
-        )
-
-        cls.__typed_da__[cls][item] = _DocArrayTyped
+        # The global scope and qualname need to give this class a unique name.
+        # Otherwise, creating another _DocArrayTyped will overwrite this one.
+        if not is_pydantic_v2:
+            change_cls_name(
+                _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
+            )
+            cls.__typed_da__[cls][item] = _DocArrayTyped
+        else:
+            change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals())
+            if sys.version_info < (3, 12):
+                cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(
+                    _DocArrayTyped, item
+                )  # type: ignore
+                # this does nothing beyond checking that item is a valid type var or str
+                # Keep the approach in #1147 to be compatible with lower versions of Python.
+            else:
+                cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item)  # type: ignore
         return cls.__typed_da__[cls][item]

     @overload
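A sketch of the caching contract `__class_getitem__` keeps under both Pydantic versions (`MyDoc` is illustrative, not part of the commit):

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str


# The parametrized container is built once and stored in cls.__typed_da__;
# under Pydantic V2 the cached object is a generic alias over _DocArrayTyped
# rather than a renamed subclass, but repeated subscripting still returns
# the same cached object.
typed = DocList[MyDoc]
assert typed is DocList[MyDoc]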

docarray/array/doc_list/doc_list.py

Lines changed: 17 additions & 8 deletions

@@ -12,6 +12,7 @@
     Union,
     cast,
     overload,
+    Callable,
 )

 from pydantic import parse_obj_as
@@ -28,7 +29,6 @@
 from docarray.utils._internal.pydantic import is_pydantic_v2

 if is_pydantic_v2:
-    from pydantic import GetCoreSchemaHandler
     from pydantic_core import core_schema

 from docarray.utils._internal._typing import safe_issubclass
@@ -45,10 +45,7 @@
 class DocList(
-    ListAdvancedIndexing[T_doc],
-    PushPullMixin,
-    IOMixinDocList,
-    AnyDocArray[T_doc],
+    ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc]
 ):
     """
     DocList is a container of Documents.
@@ -357,8 +354,20 @@ def __repr__(self):

     @classmethod
     def __get_pydantic_core_schema__(
-        cls, _source_type: Any, _handler: GetCoreSchemaHandler
+        cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
     ) -> core_schema.CoreSchema:
-        return core_schema.general_plain_validator_function(
-            cls.validate,
+        instance_schema = core_schema.is_instance_schema(cls)
+        args = getattr(source, '__args__', None)
+        if args:
+            sequence_t_schema = handler(Sequence[args[0]])
+        else:
+            sequence_t_schema = handler(Sequence)
+
+        def validate_fn(v, info):
+            # input has already been validated
+            return cls(v, validate_input_docs=False)
+
+        non_instance_schema = core_schema.with_info_after_validator_function(
+            validate_fn, sequence_t_schema
         )
+        return core_schema.union_schema([instance_schema, non_instance_schema])
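A sketch of what the new core schema means for Pydantic V2 models (`MyDoc` and `Container` are illustrative, not part of the commit):

from pydantic import BaseModel

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str


class Container(BaseModel):
    docs: DocList[MyDoc]


# The union schema accepts an existing DocList instance as-is via the
# is_instance_schema branch...
c1 = Container(docs=DocList[MyDoc]([MyDoc(text='a')]))

# ...and coerces a plain sequence through the Sequence[MyDoc] branch,
# which validates the items first and then wraps them without re-validation.
c2 = Container(docs=[MyDoc(text='b')])
assert isinstance(c2.docs, DocList)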

docarray/array/doc_list/io.py

Lines changed: 0 additions & 1 deletion

@@ -256,7 +256,6 @@ def to_bytes(
         :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
         :return: the binary serialization in bytes or None if file_ctx is passed where to store
         """
-
         with file_ctx or io.BytesIO() as bf:
             self._write_bytes(
                 bf=bf,
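Usage sketch for the byte-level round-trip this method documents (`MyDoc` is illustrative):

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str


docs = DocList[MyDoc]([MyDoc(text='hi')])
data = docs.to_bytes(protocol='protobuf')
restored = DocList[MyDoc].from_bytes(data, protocol='protobuf')
assert restored[0].text == 'hi'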

docarray/array/doc_vec/doc_vec.py

Lines changed: 4 additions & 2 deletions

@@ -198,7 +198,7 @@ def _check_doc_field_not_none(field_name, doc):
             if safe_issubclass(tensor.__class__, tensor_type):
                 field_type = tensor_type

-            if isinstance(field_type, type):
+            if isinstance(field_type, type) or safe_issubclass(field_type, AnyDocArray):
                 if tf_available and safe_issubclass(field_type, TensorFlowTensor):
                     # tf.Tensor does not allow item assignment, therefore the
                     # optimized way
@@ -335,7 +335,9 @@ def _docarray_validate(
                 return cast(T, value.to_doc_vec())
             else:
                 raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}')
-        elif isinstance(value, DocList.__class_getitem__(cls.doc_type)):
+        elif not is_pydantic_v2 and isinstance(
+            value, DocList.__class_getitem__(cls.doc_type)
+        ):
             return cast(T, value.to_doc_vec())
         elif isinstance(value, Sequence):
             return cls(value)
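The widened check matters because under Pydantic V2 a parametrized `DocList` is a generic alias, not a plain `type`. A sketch of the nested case it covers (`Inner` and `Outer` are illustrative):

from docarray import BaseDoc, DocList, DocVec


class Inner(BaseDoc):
    text: str


class Outer(BaseDoc):
    inner: DocList[Inner]


# Building the column for `inner` now also enters the type-based branch
# when the annotation is an AnyDocArray subclass rather than a plain type.
vec = DocVec[Outer]([Outer(inner=DocList[Inner]([Inner(text='a')]))])
assert len(vec) == 1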

docarray/base_doc/doc.py

Lines changed: 7 additions & 3 deletions

@@ -326,8 +326,13 @@ def _exclude_doclist(
         from docarray.array.any_array import AnyDocArray

         type_ = self._get_field_annotation(field)
-        if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray):
-            doclist_exclude_fields.append(field)
+        if is_pydantic_v2:
+            # Conservative when touching pydantic v1 logic
+            if safe_issubclass(type_, AnyDocArray):
+                doclist_exclude_fields.append(field)
+        else:
+            if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray):
+                doclist_exclude_fields.append(field)

         original_exclude = exclude
         if exclude is None:
@@ -480,7 +485,6 @@ def model_dump(  # type: ignore
         warnings: bool = True,
     ) -> Dict[str, Any]:
         def _model_dump(doc):
-
             (
                 exclude_,
                 original_exclude,
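A sketch of the serialization path the V2 branch protects: `DocList[Inner]` is not a plain `type` under Pydantic V2, so the nested field is detected by `safe_issubclass` alone (`Inner` and `Outer` are illustrative):

from docarray import BaseDoc, DocList


class Inner(BaseDoc):
    text: str


class Outer(BaseDoc):
    inner: DocList[Inner]


doc = Outer(inner=DocList[Inner]([Inner(text='a')]))

# _exclude_doclist flags `inner` so the nested DocList is dumped through
# the dedicated path instead of vanilla pydantic serialization.
dumped = doc.model_dump()
assert 'inner' in dumped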

docarray/base_doc/mixins/update.py

Lines changed: 1 addition & 3 deletions

@@ -110,9 +110,7 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
             if field_name not in FORBIDDEN_FIELDS_TO_UPDATE:
                 field_type = doc._get_field_annotation(field_name)

-                if isinstance(field_type, type) and safe_issubclass(
-                    field_type, DocList
-                ):
+                if safe_issubclass(field_type, DocList):
                     nested_docarray_fields.append(field_name)
                 else:
                     origin = get_origin(field_type)
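A sketch of the update path served by the simplified check; `safe_issubclass(DocList[Inner], DocList)` holds under both Pydantic versions, so nested DocList fields are still grouped correctly (illustrative types, not part of the commit):

from docarray import BaseDoc, DocList


class Inner(BaseDoc):
    text: str


class Outer(BaseDoc):
    title: str
    inner: DocList[Inner]


a = Outer(title='old', inner=DocList[Inner]([Inner(text='x')]))
b = Outer(id=a.id, title='new', inner=DocList[Inner]([Inner(text='y')]))

# `inner` lands in nested_docarray_fields, so update() merges the nested
# docs while scalar fields like `title` are overwritten.
a.update(b)
assert a.title == 'new'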
