diff --git a/scyjava/__init__.py b/scyjava/__init__.py index 53436634..a3af2937 100644 --- a/scyjava/__init__.py +++ b/scyjava/__init__.py @@ -12,6 +12,7 @@ import re import scyjava.config import subprocess +import threading from pathlib import Path from jpype.types import * from _jpype import _JObject @@ -223,6 +224,16 @@ def when_jvm_stops(f): # -- Utility functions -- +_thread_locals = threading.local() + +def get_local(attr: str) -> Any: + if hasattr(_thread_locals, attr): + return getattr(_thread_locals, attr) + raise AttributeError(f"scyjava has no local {attr}, it must be set using scyjava.set_local()") + +def set_local(attr: str, value: Any) -> None: + setattr(_thread_locals, attr, value) + def get_version(java_class): """Return the version of a Java class. """ VersionUtils = jimport('org.scijava.util.VersionUtils') diff --git a/tests/test_locals.py b/tests/test_locals.py new file mode 100644 index 00000000..4d75be11 --- /dev/null +++ b/tests/test_locals.py @@ -0,0 +1,32 @@ +from typing import Any +from scyjava import get_local, set_local +from threading import Thread + +def test_get_local(): + """Ensures that setting a local in one thread does not set it in another.""" + attr = 'foo' + def assert_local(expected_val: Any, equality: bool) -> bool: + try: + actual = expected_val == get_local(attr) + except AttributeError: + actual = False + assert actual == equality + + set_local(attr, 1) + assert_local(1, True) + t: Thread = Thread(target=assert_local, args=[1, False]) + t.start() + t.join() + +def test_set_local(): + """Ensures that setting a local in one thread does not set it in another.""" + attr = 'foo' + def set_local_func(val: Any) -> None: + set_local(attr, val) + + set_local(attr, 1) + t: Thread = Thread(target=set_local_func, args=[2]) + t.start() + t.join() + + assert get_local(attr) == 1