Skip to content

Commit 3cfa0b8

Browse files
authored
fix: fix storage issue in torchtensor class (#1833)
Signed-off-by: Naymul Islam <naymul504@gmail.com>
1 parent 82918fe commit 3cfa0b8

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

docarray/typing/tensor/torch_tensor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
293293
)
294294
return super().__torch_function__(func, types_, args, kwargs)
295295

296+
def __deepcopy__(self, memo):
297+
"""
298+
Custom implementation of deepcopy for TorchTensor to avoid storage sharing issues.
299+
"""
300+
# Create a new tensor with the same data and properties
301+
new_tensor = self.clone()
302+
# Set the class to the custom TorchTensor class
303+
new_tensor.__class__ = self.__class__
304+
return new_tensor
305+
296306
@classmethod
297307
def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
298308
"""Create a `tensor from a numpy array

tests/integrations/typing/test_torch_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import torch
2+
from docarray.typing.tensor.torch_tensor import TorchTensor
3+
import copy
24

35
from docarray import BaseDoc
46
from docarray.typing import TorchEmbedding, TorchTensor
@@ -25,3 +27,19 @@ class MyDocument(BaseDoc):
2527
assert isinstance(d.embedding, TorchEmbedding)
2628
assert isinstance(d.embedding, torch.Tensor)
2729
assert (d.embedding == torch.zeros((128,))).all()
30+
31+
32+
def test_torchtensor_deepcopy():
33+
# Setup
34+
original_tensor_float = TorchTensor(torch.rand(10))
35+
original_tensor_int = TorchTensor(torch.randint(0, 100, (10,)))
36+
37+
# Exercise
38+
copied_tensor_float = copy.deepcopy(original_tensor_float)
39+
copied_tensor_int = copy.deepcopy(original_tensor_int)
40+
41+
# Verify
42+
assert torch.equal(original_tensor_float, copied_tensor_float)
43+
assert original_tensor_float is not copied_tensor_float
44+
assert torch.equal(original_tensor_int, copied_tensor_int)
45+
assert original_tensor_int is not copied_tensor_int

0 commit comments

Comments
 (0)