diff --git a/.stats.yml b/.stats.yml index 0b70a4d7..ffcf5e3d 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,3 +1,3 @@ configured_endpoints: 18 -openapi_spec_hash: 153617b7252b1b12f21043b2a1246f8b -config_hash: 30422a4611d93ca69e4f1aff60b9ddb5 +openapi_spec_hash: 539798fac79a1eeebf9ac4faa0492455 +config_hash: 6dcf08c4324405f152d1da9fc11ab04a diff --git a/README.md b/README.md index 4f32c54b..86fc8a0f 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![PyPI version](https://img.shields.io/pypi/v/openlayer.svg?label=pypi%20(stable))](https://pypi.org/project/openlayer/) -The Openlayer Python library provides convenient access to the Openlayer REST API from any Python 3.8+ +The Openlayer Python library provides convenient access to the Openlayer REST API from any Python 3.9+ application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). @@ -515,7 +515,7 @@ print(openlayer.__version__) ## Requirements -Python 3.8 or higher. +Python 3.9 or higher. ## Contributing diff --git a/pyproject.toml b/pyproject.toml index ac97808a..d5e56ec5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,11 +22,10 @@ dependencies = [ "tqdm", "wrapt>=1.14.0" ] -requires-python = ">= 3.8" +requires-python = ">= 3.9" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -46,7 +45,7 @@ Homepage = "https://github.com/openlayer-ai/openlayer-python" Repository = "https://github.com/openlayer-ai/openlayer-python" [project.optional-dependencies] -aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] [tool.rye] managed = true @@ -148,7 +147,7 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. 
typeCheckingMode = "strict" -pythonVersion = "3.8" +pythonVersion = "3.9" exclude = [ "_dev", @@ -233,6 +232,8 @@ select = [ "B", # remove unused imports "F401", + # check for missing future annotations + "FA102", # bare except statements "E722", # unused arguments @@ -255,6 +256,8 @@ unfixable = [ "T203", ] +extend-safe-fixes = ["FA102"] + [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" diff --git a/requirements-dev.lock b/requirements-dev.lock index 157f47cb..6fa69762 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -57,7 +57,7 @@ httpx==0.28.1 # via httpx-aiohttp # via openlayer # via respx -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via openlayer idna==3.4 # via anyio @@ -96,11 +96,9 @@ pluggy==1.5.0 propcache==0.3.2 # via aiohttp # via yarl -pyarrow==15.0.2 +pydantic==2.11.9 # via openlayer -pydantic==2.10.3 - # via openlayer -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic pygments==2.18.0 # via rich @@ -146,6 +144,9 @@ typing-extensions==4.12.2 # via pydantic # via pydantic-core # via pyright + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic tzdata==2024.1 # via pandas urllib3==2.2.3 diff --git a/requirements.lock b/requirements.lock index 5db63e97..e8e4afbe 100644 --- a/requirements.lock +++ b/requirements.lock @@ -44,7 +44,7 @@ httpcore==1.0.9 httpx==0.28.1 # via httpx-aiohttp # via openlayer -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via openlayer idna==3.4 # via anyio @@ -63,11 +63,9 @@ pandas==2.2.2 propcache==0.3.2 # via aiohttp # via yarl -pyarrow==15.0.2 +pydantic==2.11.9 # via openlayer -pydantic==2.10.3 - # via openlayer -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic python-dateutil==2.9.0.post0 # via pandas @@ -92,6 +90,10 @@ typing-extensions==4.12.2 # via openlayer # via pydantic # via pydantic-core + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic +yarl==1.20.0 tzdata==2024.1 # via pandas urllib3==2.2.3 diff --git a/scripts/bootstrap b/scripts/bootstrap index e84fe62c..b430fee3 100755 --- a/scripts/bootstrap +++ b/scripts/bootstrap @@ -4,10 +4,18 @@ set -e cd "$(dirname "$0")/.." -if ! command -v rye >/dev/null 2>&1 && [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ] && [ -t 0 ]; then brew bundle check >/dev/null 2>&1 || { - echo "==> Installing Homebrew dependencies…" - brew bundle + echo -n "==> Install Homebrew dependencies? (y/N): " + read -r response + case "$response" in + [yY][eE][sS]|[yY]) + brew bundle + ;; + *) + ;; + esac + echo } fi diff --git a/src/openlayer/__init__.py b/src/openlayer/__init__.py index 78f0ca5d..3356befd 100644 --- a/src/openlayer/__init__.py +++ b/src/openlayer/__init__.py @@ -3,7 +3,7 @@ import typing as _t from . 
import types -from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes +from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given from ._utils import file_from_path from ._client import ( Client, @@ -48,7 +48,9 @@ "ProxiesTypes", "NotGiven", "NOT_GIVEN", + "not_given", "Omit", + "omit", "OpenlayerError", "APIError", "APIStatusError", diff --git a/src/openlayer/_base_client.py b/src/openlayer/_base_client.py index 3e13d930..8e85c020 100644 --- a/src/openlayer/_base_client.py +++ b/src/openlayer/_base_client.py @@ -42,7 +42,6 @@ from ._qs import Querystring from ._files import to_httpx_files, async_to_httpx_files from ._types import ( - NOT_GIVEN, Body, Omit, Query, @@ -57,6 +56,7 @@ RequestOptions, HttpxRequestFiles, ModelBuilderProtocol, + not_given, ) from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping from ._compat import PYDANTIC_V1, model_copy, model_dump @@ -145,9 +145,9 @@ def __init__( def __init__( self, *, - url: URL | NotGiven = NOT_GIVEN, - json: Body | NotGiven = NOT_GIVEN, - params: Query | NotGiven = NOT_GIVEN, + url: URL | NotGiven = not_given, + json: Body | NotGiven = not_given, + params: Query | NotGiven = not_given, ) -> None: self.url = url self.json = json @@ -595,7 +595,7 @@ def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalReques # we internally support defining a temporary header to override the # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response` # see _response.py for implementation details - override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN) + override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, not_given) if is_given(override_cast_to): options.headers = headers return cast(Type[ResponseT], override_cast_to) @@ -825,7 +825,7 @@ def __init__( version: str, base_url: str | URL, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1356,7 +1356,7 @@ def __init__( base_url: str | URL, _strict_response_validation: bool, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1818,8 +1818,8 @@ def make_request_options( extra_query: Query | None = None, extra_body: Body | None = None, idempotency_key: str | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - post_parser: PostParser | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + post_parser: PostParser | NotGiven = not_given, ) -> RequestOptions: """Create a dict of type RequestOptions without keys of NotGiven values.""" options: RequestOptions = {} diff --git a/src/openlayer/_client.py b/src/openlayer/_client.py index 0ae1918d..04bb9dc2 100644 --- a/src/openlayer/_client.py +++ b/src/openlayer/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import Any, Union, Mapping +from typing import Any, Mapping from typing_extensions import Self, override import httpx @@ -11,7 +11,6 @@ from . 
import _exceptions from ._qs import Querystring from ._types import ( - NOT_GIVEN, Omit, Headers, Timeout, @@ -19,6 +18,7 @@ Transport, ProxiesTypes, RequestOptions, + not_given, ) from ._utils import is_given, get_async_library from ._version import __version__ @@ -62,7 +62,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -149,9 +149,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -245,7 +245,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -332,9 +332,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, diff --git a/src/openlayer/_models.py b/src/openlayer/_models.py index 3a6017ef..fcec2cf9 100644 --- a/src/openlayer/_models.py +++ b/src/openlayer/_models.py @@ -2,6 +2,7 @@ import os import inspect +import weakref from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast from datetime import date, datetime from typing_extensions import ( @@ -256,7 +257,7 @@ def model_dump( mode: Literal["json", "python"] | str = "python", include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, @@ -264,6 +265,7 @@ def model_dump( warnings: bool | Literal["none", "warn", "error"] = True, context: dict[str, Any] | None = None, serialize_as_any: bool = False, + fallback: Callable[[Any], Any] | None = None, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -295,10 +297,12 @@ def model_dump( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, @@ -313,13 
+317,14 @@ def model_dump_json( indent: int | None = None, include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, context: dict[str, Any] | None = None, + fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json @@ -348,11 +353,13 @@ def model_dump_json( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, @@ -567,6 +574,9 @@ class CachedDiscriminatorType(Protocol): __discriminator__: DiscriminatorDetails +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() + + class DiscriminatorDetails: field_name: str """The name of the discriminator field in the variant class, e.g. @@ -609,8 +619,9 @@ def __init__( def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: - if isinstance(union, CachedDiscriminatorType): - return union.__discriminator__ + cached = DISCRIMINATOR_CACHE.get(union) + if cached is not None: + return cached discriminator_field_name: str | None = None @@ -663,7 +674,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, discriminator_field=discriminator_field_name, discriminator_alias=discriminator_alias, ) - cast(CachedDiscriminatorType, union).__discriminator__ = details + DISCRIMINATOR_CACHE.setdefault(union, details) return details diff --git a/src/openlayer/_qs.py b/src/openlayer/_qs.py index 274320ca..ada6fd3f 100644 --- a/src/openlayer/_qs.py +++ b/src/openlayer/_qs.py @@ -4,7 +4,7 @@ from urllib.parse import parse_qs, urlencode from typing_extensions import Literal, get_args -from ._types import NOT_GIVEN, NotGiven, NotGivenOr +from ._types import NotGiven, not_given from ._utils import flatten _T = TypeVar("_T") @@ -41,8 +41,8 @@ def stringify( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> str: return urlencode( self.stringify_items( @@ -56,8 +56,8 @@ def stringify_items( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> list[tuple[str, str]]: opts = Options( qs=self, @@ -143,8 +143,8 @@ def __init__( self, qs: Querystring = _qs, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> None: self.array_format = qs.array_format if isinstance(array_format, NotGiven) else 
array_format self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format diff --git a/src/openlayer/_streaming.py b/src/openlayer/_streaming.py index 8eb34af1..a4a96683 100644 --- a/src/openlayer/_streaming.py +++ b/src/openlayer/_streaming.py @@ -57,9 +57,8 @@ def __stream__(self) -> Iterator[_T]: for sse in iterator: yield process_data(data=sse.json(), cast_to=cast_to, response=response) - # Ensure the entire stream is consumed - for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + response.close() def __enter__(self) -> Self: return self @@ -121,9 +120,8 @@ async def __stream__(self) -> AsyncIterator[_T]: async for sse in iterator: yield process_data(data=sse.json(), cast_to=cast_to, response=response) - # Ensure the entire stream is consumed - async for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + await response.aclose() async def __aenter__(self) -> Self: return self diff --git a/src/openlayer/_types.py b/src/openlayer/_types.py index 8d9dfe1f..c182d468 100644 --- a/src/openlayer/_types.py +++ b/src/openlayer/_types.py @@ -117,18 +117,21 @@ class RequestOptions(TypedDict, total=False): # Sentinel class used until PEP 0661 is accepted class NotGiven: """ - A sentinel singleton class used to distinguish omitted keyword arguments - from those passed in with the value None (which may have different behavior). + For parameters with a meaningful None value, we need to distinguish between + the user explicitly passing None, and the user not passing the parameter at + all. + + User code shouldn't need to use not_given directly. For example: ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... + def create(timeout: Timeout | None | NotGiven = not_given): ... - get(timeout=1) # 1s timeout - get(timeout=None) # No timeout - get() # Default timeout behavior, which may not be statically known at the method definition. + create(timeout=1) # 1s timeout + create(timeout=None) # No timeout + create() # Default timeout behavior ``` """ @@ -140,13 +143,14 @@ def __repr__(self) -> str: return "NOT_GIVEN" -NotGivenOr = Union[_T, NotGiven] +not_given = NotGiven() +# for backwards compatibility: NOT_GIVEN = NotGiven() class Omit: - """In certain situations you need to be able to represent a case where a default value has - to be explicitly removed and `None` is not an appropriate substitute, for example: + """ + To explicitly omit something from being sent in a request, use `omit`. 
```py # as the default `Content-Type` header is `application/json` that will be sent @@ -156,8 +160,8 @@ class Omit: # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' client.post(..., headers={"Content-Type": "multipart/form-data"}) - # instead you can remove the default `application/json` header by passing Omit - client.post(..., headers={"Content-Type": Omit()}) + # instead you can remove the default `application/json` header by passing omit + client.post(..., headers={"Content-Type": omit}) ``` """ @@ -165,6 +169,9 @@ def __bool__(self) -> Literal[False]: return False +omit = Omit() + + @runtime_checkable class ModelBuilderProtocol(Protocol): @classmethod diff --git a/src/openlayer/_utils/_sync.py b/src/openlayer/_utils/_sync.py index ad7ec71b..f6027c18 100644 --- a/src/openlayer/_utils/_sync.py +++ b/src/openlayer/_utils/_sync.py @@ -1,10 +1,8 @@ from __future__ import annotations -import sys import asyncio import functools -import contextvars -from typing import Any, TypeVar, Callable, Awaitable +from typing import TypeVar, Callable, Awaitable from typing_extensions import ParamSpec import anyio @@ -15,34 +13,11 @@ T_ParamSpec = ParamSpec("T_ParamSpec") -if sys.version_info >= (3, 9): - _asyncio_to_thread = asyncio.to_thread -else: - # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread - # for Python 3.8 support - async def _asyncio_to_thread( - func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs - ) -> Any: - """Asynchronously run function *func* in a separate thread. - - Any *args and **kwargs supplied for this function are directly passed - to *func*. Also, the current :class:`contextvars.Context` is propagated, - allowing context variables from the main thread to be accessed in the - separate thread. - - Returns a coroutine that can be awaited to get the eventual result of *func*. - """ - loop = asyncio.events.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - - async def to_thread( func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs ) -> T_Retval: if sniffio.current_async_library() == "asyncio": - return await _asyncio_to_thread(func, *args, **kwargs) + return await asyncio.to_thread(func, *args, **kwargs) return await anyio.to_thread.run_sync( functools.partial(func, *args, **kwargs), @@ -53,10 +28,7 @@ async def to_thread( def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ Take a blocking function and create an async one that receives the same - positional and keyword arguments. For python version 3.9 and above, it uses - asyncio.to_thread to run the function in a separate thread. For python version - 3.8, it uses locally defined copy of the asyncio.to_thread function which was - introduced in python 3.9. + positional and keyword arguments. 
Usage: diff --git a/src/openlayer/_utils/_transform.py b/src/openlayer/_utils/_transform.py index c19124f0..52075492 100644 --- a/src/openlayer/_utils/_transform.py +++ b/src/openlayer/_utils/_transform.py @@ -268,7 +268,7 @@ def _transform_typeddict( annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): if not is_given(value): - # we don't need to include `NotGiven` values here as they'll + # we don't need to include omitted values here as they'll # be stripped out before the request is sent anyway continue @@ -434,7 +434,7 @@ async def _async_transform_typeddict( annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): if not is_given(value): - # we don't need to include `NotGiven` values here as they'll + # we don't need to include omitted values here as they'll # be stripped out before the request is sent anyway continue diff --git a/src/openlayer/_utils/_utils.py b/src/openlayer/_utils/_utils.py index f0818595..eec7f4a1 100644 --- a/src/openlayer/_utils/_utils.py +++ b/src/openlayer/_utils/_utils.py @@ -21,7 +21,7 @@ import sniffio -from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike +from .._types import Omit, NotGiven, FileTypes, HeadersLike _T = TypeVar("_T") _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) @@ -63,7 +63,7 @@ def _extract_items( try: key = path[index] except IndexError: - if isinstance(obj, NotGiven): + if not is_given(obj): # no value was provided - we can safely ignore return [] @@ -126,14 +126,14 @@ def _extract_items( return [] -def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: - return not isinstance(obj, NotGiven) +def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: + return not isinstance(obj, NotGiven) and not isinstance(obj, Omit) # Type safe methods for narrowing types with TypeVars. # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], # however this cause Pyright to rightfully report errors. As we know we don't -# care about the contained types we can safely use `object` in it's place. +# care about the contained types we can safely use `object` in its place. # # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. # `is_*` is for when you're dealing with an unknown input diff --git a/src/openlayer/resources/commits/commits.py b/src/openlayer/resources/commits/commits.py index 64ae8377..df43c6e2 100644 --- a/src/openlayer/resources/commits/commits.py +++ b/src/openlayer/resources/commits/commits.py @@ -4,7 +4,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Query, Headers, NotGiven, not_given from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -60,7 +60,7 @@ def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CommitRetrieveResponse: """ Retrieve a project version (commit) by its id. @@ -118,7 +118,7 @@ async def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CommitRetrieveResponse: """ Retrieve a project version (commit) by its id. 
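
A minimal usage sketch of the `not_given` timeout default introduced in the hunks above (the `Openlayer` client class name and the positional commit id are assumptions based on this resource layout; only the method and parameter names come from the diff):

```py
import httpx

from openlayer import Openlayer  # client class name assumed

client = Openlayer()  # assumes OPENLAYER_API_KEY is set in the environment

# Omitting `timeout` keeps the client-level default, `timeout=None` disables
# the timeout, and an explicit httpx.Timeout overrides it for this call only.
commit = client.commits.retrieve(
    "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",  # hypothetical commit id
    timeout=httpx.Timeout(10.0, connect=2.0),
)
```
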
diff --git a/src/openlayer/resources/commits/test_results.py b/src/openlayer/resources/commits/test_results.py index b9b6e70a..4c848588 100644 --- a/src/openlayer/resources/commits/test_results.py +++ b/src/openlayer/resources/commits/test_results.py @@ -6,7 +6,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -49,17 +49,17 @@ def list( self, project_version_id: str, *, - include_archived: bool | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, - status: Literal["running", "passing", "failing", "skipped", "error"] | NotGiven = NOT_GIVEN, - type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | NotGiven = NOT_GIVEN, + include_archived: bool | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, + status: Literal["running", "passing", "failing", "skipped", "error"] | Omit = omit, + type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestResultListResponse: """ List the test results for a project commit (project version). @@ -133,17 +133,17 @@ async def list( self, project_version_id: str, *, - include_archived: bool | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, - status: Literal["running", "passing", "failing", "skipped", "error"] | NotGiven = NOT_GIVEN, - type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | NotGiven = NOT_GIVEN, + include_archived: bool | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, + status: Literal["running", "passing", "failing", "skipped", "error"] | Omit = omit, + type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestResultListResponse: """ List the test results for a project commit (project version). 
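
A companion sketch for the `omit` default on request parameters, assuming the same client and a `commits.test_results` attribute chain mirroring the resource path above:

```py
from openlayer import Openlayer  # client class name assumed

client = Openlayer()

# Parameters left at their `omit` default are stripped before the request is
# sent, so only `page` and `status` end up in the query string here.
results = client.commits.test_results.list(
    "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",  # hypothetical project version id
    page=1,
    status="passing",
)
```
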
diff --git a/src/openlayer/resources/inference_pipelines/data.py b/src/openlayer/resources/inference_pipelines/data.py index 58af5086..9d2b6370 100644 --- a/src/openlayer/resources/inference_pipelines/data.py +++ b/src/openlayer/resources/inference_pipelines/data.py @@ -6,7 +6,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Query, Headers, NotGiven, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -54,7 +54,7 @@ def stream( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DataStreamResponse: """ Publish an inference data point to an inference pipeline. @@ -124,7 +124,7 @@ async def stream( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DataStreamResponse: """ Publish an inference data point to an inference pipeline. diff --git a/src/openlayer/resources/inference_pipelines/inference_pipelines.py b/src/openlayer/resources/inference_pipelines/inference_pipelines.py index c9c29f5c..ece56525 100644 --- a/src/openlayer/resources/inference_pipelines/inference_pipelines.py +++ b/src/openlayer/resources/inference_pipelines/inference_pipelines.py @@ -24,7 +24,7 @@ AsyncRowsResourceWithStreamingResponse, ) from ...types import inference_pipeline_update_params, inference_pipeline_retrieve_params -from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven +from ..._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -85,13 +85,13 @@ def retrieve( self, inference_pipeline_id: str, *, - expand: List[Literal["project", "workspace"]] | NotGiven = NOT_GIVEN, + expand: List[Literal["project", "workspace"]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineRetrieveResponse: """ Retrieve inference pipeline. @@ -129,15 +129,15 @@ def update( self, inference_pipeline_id: str, *, - description: Optional[str] | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, - reference_dataset_uri: Optional[str] | NotGiven = NOT_GIVEN, + description: Optional[str] | Omit = omit, + name: str | Omit = omit, + reference_dataset_uri: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineUpdateResponse: """ Update inference pipeline. @@ -187,7 +187,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete inference pipeline. @@ -251,13 +251,13 @@ async def retrieve( self, inference_pipeline_id: str, *, - expand: List[Literal["project", "workspace"]] | NotGiven = NOT_GIVEN, + expand: List[Literal["project", "workspace"]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineRetrieveResponse: """ Retrieve inference pipeline. @@ -295,15 +295,15 @@ async def update( self, inference_pipeline_id: str, *, - description: Optional[str] | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, - reference_dataset_uri: Optional[str] | NotGiven = NOT_GIVEN, + description: Optional[str] | Omit = omit, + name: str | Omit = omit, + reference_dataset_uri: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineUpdateResponse: """ Update inference pipeline. @@ -353,7 +353,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete inference pipeline. diff --git a/src/openlayer/resources/inference_pipelines/rows.py b/src/openlayer/resources/inference_pipelines/rows.py index c6358556..0c77dfb1 100644 --- a/src/openlayer/resources/inference_pipelines/rows.py +++ b/src/openlayer/resources/inference_pipelines/rows.py @@ -6,7 +6,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -49,13 +49,13 @@ def update( *, inference_id: str, row: object, - config: Optional[row_update_params.Config] | NotGiven = NOT_GIVEN, + config: Optional[row_update_params.Config] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> RowUpdateResponse: """ Update an inference data point in an inference pipeline. @@ -121,13 +121,13 @@ async def update( *, inference_id: str, row: object, - config: Optional[row_update_params.Config] | NotGiven = NOT_GIVEN, + config: Optional[row_update_params.Config] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> RowUpdateResponse: """ Update an inference data point in an inference pipeline. diff --git a/src/openlayer/resources/inference_pipelines/test_results.py b/src/openlayer/resources/inference_pipelines/test_results.py index c4c87494..5344d554 100644 --- a/src/openlayer/resources/inference_pipelines/test_results.py +++ b/src/openlayer/resources/inference_pipelines/test_results.py @@ -6,7 +6,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -49,16 +49,16 @@ def list( self, inference_pipeline_id: str, *, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, - status: Literal["running", "passing", "failing", "skipped", "error"] | NotGiven = NOT_GIVEN, - type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | NotGiven = NOT_GIVEN, + page: int | Omit = omit, + per_page: int | Omit = omit, + status: Literal["running", "passing", "failing", "skipped", "error"] | Omit = omit, + type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestResultListResponse: """ List the latest test results for an inference pipeline. @@ -131,16 +131,16 @@ async def list( self, inference_pipeline_id: str, *, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, - status: Literal["running", "passing", "failing", "skipped", "error"] | NotGiven = NOT_GIVEN, - type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | NotGiven = NOT_GIVEN, + page: int | Omit = omit, + per_page: int | Omit = omit, + status: Literal["running", "passing", "failing", "skipped", "error"] | Omit = omit, + type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestResultListResponse: """ List the latest test results for an inference pipeline. diff --git a/src/openlayer/resources/projects/commits.py b/src/openlayer/resources/projects/commits.py index bec55f37..381d8b2d 100644 --- a/src/openlayer/resources/projects/commits.py +++ b/src/openlayer/resources/projects/commits.py @@ -6,7 +6,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -50,14 +50,14 @@ def create( *, commit: commit_create_params.Commit, storage_uri: str, - archived: Optional[bool] | NotGiven = NOT_GIVEN, - deployment_status: str | NotGiven = NOT_GIVEN, + archived: Optional[bool] | Omit = omit, + deployment_status: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CommitCreateResponse: """ Create a new commit (project version) in a project. @@ -102,14 +102,14 @@ def list( self, project_id: str, *, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, + page: int | Omit = omit, + per_page: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CommitListResponse: """ List the commits (project versions) in a project. @@ -174,14 +174,14 @@ async def create( *, commit: commit_create_params.Commit, storage_uri: str, - archived: Optional[bool] | NotGiven = NOT_GIVEN, - deployment_status: str | NotGiven = NOT_GIVEN, + archived: Optional[bool] | Omit = omit, + deployment_status: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CommitCreateResponse: """ Create a new commit (project version) in a project. 
@@ -226,14 +226,14 @@ async def list( self, project_id: str, *, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, + page: int | Omit = omit, + per_page: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CommitListResponse: """ List the commits (project versions) in a project. diff --git a/src/openlayer/resources/projects/inference_pipelines.py b/src/openlayer/resources/projects/inference_pipelines.py index c380a19a..fbc1b9c5 100644 --- a/src/openlayer/resources/projects/inference_pipelines.py +++ b/src/openlayer/resources/projects/inference_pipelines.py @@ -6,7 +6,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -50,14 +50,15 @@ def create( *, description: Optional[str], name: str, - project: Optional[inference_pipeline_create_params.Project] | NotGiven = NOT_GIVEN, - workspace: Optional[inference_pipeline_create_params.Workspace] | NotGiven = NOT_GIVEN, + data_backend: Optional[inference_pipeline_create_params.DataBackend] | Omit = omit, + project: Optional[inference_pipeline_create_params.Project] | Omit = omit, + workspace: Optional[inference_pipeline_create_params.Workspace] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineCreateResponse: """ Create an inference pipeline in a project. @@ -83,6 +84,7 @@ def create( { "description": description, "name": name, + "data_backend": data_backend, "project": project, "workspace": workspace, }, @@ -98,15 +100,15 @@ def list( self, project_id: str, *, - name: str | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, + name: str | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineListResponse: """ List the inference pipelines in a project. 
@@ -174,14 +176,15 @@ async def create( *, description: Optional[str], name: str, - project: Optional[inference_pipeline_create_params.Project] | NotGiven = NOT_GIVEN, - workspace: Optional[inference_pipeline_create_params.Workspace] | NotGiven = NOT_GIVEN, + data_backend: Optional[inference_pipeline_create_params.DataBackend] | Omit = omit, + project: Optional[inference_pipeline_create_params.Project] | Omit = omit, + workspace: Optional[inference_pipeline_create_params.Workspace] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineCreateResponse: """ Create an inference pipeline in a project. @@ -207,6 +210,7 @@ async def create( { "description": description, "name": name, + "data_backend": data_backend, "project": project, "workspace": workspace, }, @@ -222,15 +226,15 @@ async def list( self, project_id: str, *, - name: str | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, + name: str | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InferencePipelineListResponse: """ List the inference pipelines in a project. diff --git a/src/openlayer/resources/projects/projects.py b/src/openlayer/resources/projects/projects.py index c19b911f..c5aba51c 100644 --- a/src/openlayer/resources/projects/projects.py +++ b/src/openlayer/resources/projects/projects.py @@ -24,7 +24,7 @@ CommitsResourceWithStreamingResponse, AsyncCommitsResourceWithStreamingResponse, ) -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -86,13 +86,13 @@ def create( *, name: str, task_type: Literal["llm-base", "tabular-classification", "tabular-regression", "text-classification"], - description: Optional[str] | NotGiven = NOT_GIVEN, + description: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ProjectCreateResponse: """ Create a project in your workspace. 
@@ -131,17 +131,17 @@ def create( def list( self, *, - name: str | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, + name: str | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, task_type: Literal["llm-base", "tabular-classification", "tabular-regression", "text-classification"] - | NotGiven = NOT_GIVEN, + | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ProjectListResponse: """ List your workspace's projects. @@ -221,13 +221,13 @@ async def create( *, name: str, task_type: Literal["llm-base", "tabular-classification", "tabular-regression", "text-classification"], - description: Optional[str] | NotGiven = NOT_GIVEN, + description: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ProjectCreateResponse: """ Create a project in your workspace. @@ -266,17 +266,17 @@ async def create( async def list( self, *, - name: str | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, + name: str | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, task_type: Literal["llm-base", "tabular-classification", "tabular-regression", "text-classification"] - | NotGiven = NOT_GIVEN, + | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ProjectListResponse: """ List your workspace's projects. 
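
The same pattern from the caller's side for project creation — `description` now defaults to `omit` instead of `NOT_GIVEN`, so it is simply dropped from the request body when not supplied (client class name assumed; the parameter names and the `task_type` literal come from the diff):

```py
from openlayer import Openlayer  # client class name assumed

client = Openlayer()

# `description` defaults to `omit`, so it is left out of the request body.
project = client.projects.create(
    name="churn-prediction",  # illustrative project name
    task_type="tabular-classification",
)
```
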
diff --git a/src/openlayer/resources/projects/tests.py b/src/openlayer/resources/projects/tests.py index a795c811..ed102a19 100644 --- a/src/openlayer/resources/projects/tests.py +++ b/src/openlayer/resources/projects/tests.py @@ -7,7 +7,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -99,20 +99,24 @@ def create( ], thresholds: Iterable[test_create_params.Threshold], type: Literal["integrity", "consistency", "performance"], - archived: bool | NotGiven = NOT_GIVEN, - delay_window: Optional[float] | NotGiven = NOT_GIVEN, - evaluation_window: Optional[float] | NotGiven = NOT_GIVEN, - uses_ml_model: bool | NotGiven = NOT_GIVEN, - uses_production_data: bool | NotGiven = NOT_GIVEN, - uses_reference_dataset: bool | NotGiven = NOT_GIVEN, - uses_training_dataset: bool | NotGiven = NOT_GIVEN, - uses_validation_dataset: bool | NotGiven = NOT_GIVEN, + archived: bool | Omit = omit, + default_to_all_pipelines: Optional[bool] | Omit = omit, + delay_window: Optional[float] | Omit = omit, + evaluation_window: Optional[float] | Omit = omit, + exclude_pipelines: Optional[SequenceNotStr[str]] | Omit = omit, + include_historical_data: Optional[bool] | Omit = omit, + include_pipelines: Optional[SequenceNotStr[str]] | Omit = omit, + uses_ml_model: bool | Omit = omit, + uses_production_data: bool | Omit = omit, + uses_reference_dataset: bool | Omit = omit, + uses_training_dataset: bool | Omit = omit, + uses_validation_dataset: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestCreateResponse: """ Create a test. @@ -128,11 +132,23 @@ def create( archived: Whether the test is archived. + default_to_all_pipelines: Whether to apply the test to all pipelines (data sources) or to a specific set + of pipelines. Only applies to tests that use production data. + delay_window: The delay window in seconds. Only applies to tests that use production data. evaluation_window: The evaluation window in seconds. Only applies to tests that use production data. + exclude_pipelines: Array of pipelines (data sources) to which the test should not be applied. Only + applies to tests that use production data. + + include_historical_data: Whether to include historical data in the test result. Only applies to tests + that use production data. + + include_pipelines: Array of pipelines (data sources) to which the test should be applied. Only + applies to tests that use production data. + uses_ml_model: Whether the test uses an ML model. uses_production_data: Whether the test uses production data (monitoring mode only). 
@@ -163,8 +179,12 @@ def create( "thresholds": thresholds, "type": type, "archived": archived, + "default_to_all_pipelines": default_to_all_pipelines, "delay_window": delay_window, "evaluation_window": evaluation_window, + "exclude_pipelines": exclude_pipelines, + "include_historical_data": include_historical_data, + "include_pipelines": include_pipelines, "uses_ml_model": uses_ml_model, "uses_production_data": uses_production_data, "uses_reference_dataset": uses_reference_dataset, @@ -189,7 +209,7 @@ def update( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestUpdateResponse: """ Update tests. @@ -218,19 +238,19 @@ def list( self, project_id: str, *, - include_archived: bool | NotGiven = NOT_GIVEN, - origin_version_id: Optional[str] | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, - suggested: bool | NotGiven = NOT_GIVEN, - type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | NotGiven = NOT_GIVEN, - uses_production_data: Optional[bool] | NotGiven = NOT_GIVEN, + include_archived: bool | Omit = omit, + origin_version_id: Optional[str] | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, + suggested: bool | Omit = omit, + type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | Omit = omit, + uses_production_data: Optional[bool] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestListResponse: """ List tests under a project. @@ -356,20 +376,24 @@ async def create( ], thresholds: Iterable[test_create_params.Threshold], type: Literal["integrity", "consistency", "performance"], - archived: bool | NotGiven = NOT_GIVEN, - delay_window: Optional[float] | NotGiven = NOT_GIVEN, - evaluation_window: Optional[float] | NotGiven = NOT_GIVEN, - uses_ml_model: bool | NotGiven = NOT_GIVEN, - uses_production_data: bool | NotGiven = NOT_GIVEN, - uses_reference_dataset: bool | NotGiven = NOT_GIVEN, - uses_training_dataset: bool | NotGiven = NOT_GIVEN, - uses_validation_dataset: bool | NotGiven = NOT_GIVEN, + archived: bool | Omit = omit, + default_to_all_pipelines: Optional[bool] | Omit = omit, + delay_window: Optional[float] | Omit = omit, + evaluation_window: Optional[float] | Omit = omit, + exclude_pipelines: Optional[SequenceNotStr[str]] | Omit = omit, + include_historical_data: Optional[bool] | Omit = omit, + include_pipelines: Optional[SequenceNotStr[str]] | Omit = omit, + uses_ml_model: bool | Omit = omit, + uses_production_data: bool | Omit = omit, + uses_reference_dataset: bool | Omit = omit, + uses_training_dataset: bool | Omit = omit, + uses_validation_dataset: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestCreateResponse: """ Create a test. @@ -385,11 +409,23 @@ async def create( archived: Whether the test is archived. + default_to_all_pipelines: Whether to apply the test to all pipelines (data sources) or to a specific set + of pipelines. Only applies to tests that use production data. + delay_window: The delay window in seconds. Only applies to tests that use production data. evaluation_window: The evaluation window in seconds. Only applies to tests that use production data. + exclude_pipelines: Array of pipelines (data sources) to which the test should not be applied. Only + applies to tests that use production data. + + include_historical_data: Whether to include historical data in the test result. Only applies to tests + that use production data. + + include_pipelines: Array of pipelines (data sources) to which the test should be applied. Only + applies to tests that use production data. + uses_ml_model: Whether the test uses an ML model. uses_production_data: Whether the test uses production data (monitoring mode only). @@ -420,8 +456,12 @@ async def create( "thresholds": thresholds, "type": type, "archived": archived, + "default_to_all_pipelines": default_to_all_pipelines, "delay_window": delay_window, "evaluation_window": evaluation_window, + "exclude_pipelines": exclude_pipelines, + "include_historical_data": include_historical_data, + "include_pipelines": include_pipelines, "uses_ml_model": uses_ml_model, "uses_production_data": uses_production_data, "uses_reference_dataset": uses_reference_dataset, @@ -446,7 +486,7 @@ async def update( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestUpdateResponse: """ Update tests. @@ -475,19 +515,19 @@ async def list( self, project_id: str, *, - include_archived: bool | NotGiven = NOT_GIVEN, - origin_version_id: Optional[str] | NotGiven = NOT_GIVEN, - page: int | NotGiven = NOT_GIVEN, - per_page: int | NotGiven = NOT_GIVEN, - suggested: bool | NotGiven = NOT_GIVEN, - type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | NotGiven = NOT_GIVEN, - uses_production_data: Optional[bool] | NotGiven = NOT_GIVEN, + include_archived: bool | Omit = omit, + origin_version_id: Optional[str] | Omit = omit, + page: int | Omit = omit, + per_page: int | Omit = omit, + suggested: bool | Omit = omit, + type: Literal["integrity", "consistency", "performance", "fairness", "robustness"] | Omit = omit, + uses_production_data: Optional[bool] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> TestListResponse: """ List tests under a project. 
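As a usage sketch of the new pipeline-scoping arguments on `projects.tests.create` shown above (illustrative only: the project/pipeline IDs and threshold contents are placeholders, and the pre-existing required fields `name`, `description`, `thresholds`, and `type` are assumed from the full signature, which this hunk only shows in part):

from openlayer import Openlayer

client = Openlayer()  # API key is typically read from the OPENLAYER_API_KEY environment variable

test = client.projects.tests.create(
    project_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",  # placeholder project ID
    name="No duplicate rows",
    description="Fails when duplicate rows show up in production data.",
    type="integrity",
    thresholds=[{"insight_name": "duplicateRowCount", "measurement": "duplicateRowCount", "operator": "<=", "value": 0}],  # illustrative threshold
    uses_production_data=True,
    # New in this change: scope the test to specific pipelines (data sources)
    # instead of applying it to every pipeline in the project.
    default_to_all_pipelines=False,
    include_pipelines=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"],  # placeholder pipeline ID
    include_historical_data=True,
)

Passing `exclude_pipelines` instead of `include_pipelines` inverts the selection, and omitting all of these keeps the previous behavior, since each of the new arguments defaults to `omit`.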
diff --git a/src/openlayer/resources/storage/presigned_url.py b/src/openlayer/resources/storage/presigned_url.py index 2ed0ace6..96886fe8 100644 --- a/src/openlayer/resources/storage/presigned_url.py +++ b/src/openlayer/resources/storage/presigned_url.py @@ -4,7 +4,7 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Query, Headers, NotGiven, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -50,7 +50,7 @@ def create( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> PresignedURLCreateResponse: """ Retrieve a presigned url to post storage artifacts. @@ -110,7 +110,7 @@ async def create( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> PresignedURLCreateResponse: """ Retrieve a presigned url to post storage artifacts. diff --git a/src/openlayer/types/inference_pipeline_retrieve_response.py b/src/openlayer/types/inference_pipeline_retrieve_response.py index b6d61869..2e589f5f 100644 --- a/src/openlayer/types/inference_pipeline_retrieve_response.py +++ b/src/openlayer/types/inference_pipeline_retrieve_response.py @@ -1,8 +1,8 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import List, Optional +from typing import List, Union, Optional from datetime import date, datetime -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias from pydantic import Field as FieldInfo @@ -11,6 +11,18 @@ __all__ = [ "InferencePipelineRetrieveResponse", "Links", + "DataBackend", + "DataBackendUnionMember0", + "DataBackendUnionMember0Config", + "DataBackendBackendType", + "DataBackendUnionMember2", + "DataBackendUnionMember2Config", + "DataBackendUnionMember3", + "DataBackendUnionMember3Config", + "DataBackendUnionMember4", + "DataBackendUnionMember4Config", + "DataBackendUnionMember5", + "DataBackendUnionMember5Config", "Project", "ProjectLinks", "ProjectGitRepo", @@ -23,6 +35,167 @@ class Links(BaseModel): app: str +class DataBackendUnionMember0Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class DataBackendUnionMember0(BaseModel): + backend_type: Literal["bigquery"] = FieldInfo(alias="backendType") + + bigquery_connection_id: Optional[str] = FieldInfo(alias="bigqueryConnectionId", default=None) + + dataset_id: str = FieldInfo(alias="datasetId") + + project_id: str = FieldInfo(alias="projectId") + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + partition_type: Optional[Literal["DAY", "MONTH", "YEAR"]] = FieldInfo(alias="partitionType", default=None) + + +class DataBackendBackendType(BaseModel): + backend_type: Literal["default"] = FieldInfo(alias="backendType") + + +class DataBackendUnionMember2Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember2(BaseModel): + backend_type: Literal["snowflake"] = FieldInfo(alias="backendType") + + database: str + + schema_: str = FieldInfo(alias="schema") + + snowflake_connection_id: Optional[str] = FieldInfo(alias="snowflakeConnectionId", default=None) + + table: Optional[str] = None + + +class DataBackendUnionMember3Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember3(BaseModel): + backend_type: Literal["databricks_dtl"] = FieldInfo(alias="backendType") + + databricks_dtl_connection_id: Optional[str] = FieldInfo(alias="databricksDtlConnectionId", default=None) + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + +class DataBackendUnionMember4Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class DataBackendUnionMember4(BaseModel): + backend_type: Literal["redshift"] = FieldInfo(alias="backendType") + + redshift_connection_id: Optional[str] = FieldInfo(alias="redshiftConnectionId", default=None) + + schema_name: str = FieldInfo(alias="schemaName") + + table_name: str = FieldInfo(alias="tableName") + + +class DataBackendUnionMember5Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember5(BaseModel): + backend_type: Literal["postgres"] = FieldInfo(alias="backendType") + + database: str + + postgres_connection_id: Optional[str] = FieldInfo(alias="postgresConnectionId", default=None) + + schema_: str = FieldInfo(alias="schema") + + table: Optional[str] = None + + +DataBackend: TypeAlias = Union[ + DataBackendUnionMember0, + DataBackendBackendType, + DataBackendUnionMember2, + DataBackendUnionMember3, + DataBackendUnionMember4, + DataBackendUnionMember5, + None, +] + + class ProjectLinks(BaseModel): app: str @@ -203,8 +376,16 @@ class InferencePipelineRetrieveResponse(BaseModel): total_goal_count: int = FieldInfo(alias="totalGoalCount") """The total number of tests.""" + data_backend: Optional[DataBackend] = FieldInfo(alias="dataBackend", default=None) + + date_last_polled: Optional[datetime] = FieldInfo(alias="dateLastPolled", default=None) + """The last time the data was polled.""" + project: Optional[Project] = None + total_records_count: Optional[int] = FieldInfo(alias="totalRecordsCount", default=None) + """The total number of records in the data backend.""" + workspace: Optional[Workspace] = None workspace_id: Optional[str] = FieldInfo(alias="workspaceId", default=None) diff --git a/src/openlayer/types/inference_pipeline_update_response.py b/src/openlayer/types/inference_pipeline_update_response.py index e8a8638c..acf7011b 100644 --- a/src/openlayer/types/inference_pipeline_update_response.py +++ b/src/openlayer/types/inference_pipeline_update_response.py @@ -1,8 +1,8 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
-from typing import List, Optional +from typing import List, Union, Optional from datetime import date, datetime -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias from pydantic import Field as FieldInfo @@ -11,6 +11,18 @@ __all__ = [ "InferencePipelineUpdateResponse", "Links", + "DataBackend", + "DataBackendUnionMember0", + "DataBackendUnionMember0Config", + "DataBackendBackendType", + "DataBackendUnionMember2", + "DataBackendUnionMember2Config", + "DataBackendUnionMember3", + "DataBackendUnionMember3Config", + "DataBackendUnionMember4", + "DataBackendUnionMember4Config", + "DataBackendUnionMember5", + "DataBackendUnionMember5Config", "Project", "ProjectLinks", "ProjectGitRepo", @@ -23,6 +35,167 @@ class Links(BaseModel): app: str +class DataBackendUnionMember0Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember0(BaseModel): + backend_type: Literal["bigquery"] = FieldInfo(alias="backendType") + + bigquery_connection_id: Optional[str] = FieldInfo(alias="bigqueryConnectionId", default=None) + + dataset_id: str = FieldInfo(alias="datasetId") + + project_id: str = FieldInfo(alias="projectId") + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + partition_type: Optional[Literal["DAY", "MONTH", "YEAR"]] = FieldInfo(alias="partitionType", default=None) + + +class DataBackendBackendType(BaseModel): + backend_type: Literal["default"] = FieldInfo(alias="backendType") + + +class DataBackendUnionMember2Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class DataBackendUnionMember2(BaseModel): + backend_type: Literal["snowflake"] = FieldInfo(alias="backendType") + + database: str + + schema_: str = FieldInfo(alias="schema") + + snowflake_connection_id: Optional[str] = FieldInfo(alias="snowflakeConnectionId", default=None) + + table: Optional[str] = None + + +class DataBackendUnionMember3Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember3(BaseModel): + backend_type: Literal["databricks_dtl"] = FieldInfo(alias="backendType") + + databricks_dtl_connection_id: Optional[str] = FieldInfo(alias="databricksDtlConnectionId", default=None) + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + +class DataBackendUnionMember4Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember4(BaseModel): + backend_type: Literal["redshift"] = FieldInfo(alias="backendType") + + redshift_connection_id: Optional[str] = FieldInfo(alias="redshiftConnectionId", default=None) + + schema_name: str = FieldInfo(alias="schemaName") + + table_name: str = FieldInfo(alias="tableName") + + +class DataBackendUnionMember5Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class DataBackendUnionMember5(BaseModel): + backend_type: Literal["postgres"] = FieldInfo(alias="backendType") + + database: str + + postgres_connection_id: Optional[str] = FieldInfo(alias="postgresConnectionId", default=None) + + schema_: str = FieldInfo(alias="schema") + + table: Optional[str] = None + + +DataBackend: TypeAlias = Union[ + DataBackendUnionMember0, + DataBackendBackendType, + DataBackendUnionMember2, + DataBackendUnionMember3, + DataBackendUnionMember4, + DataBackendUnionMember5, + None, +] + + class ProjectLinks(BaseModel): app: str @@ -203,8 +376,16 @@ class InferencePipelineUpdateResponse(BaseModel): total_goal_count: int = FieldInfo(alias="totalGoalCount") """The total number of tests.""" + data_backend: Optional[DataBackend] = FieldInfo(alias="dataBackend", default=None) + + date_last_polled: Optional[datetime] = FieldInfo(alias="dateLastPolled", default=None) + """The last time the data was polled.""" + project: Optional[Project] = None + total_records_count: Optional[int] = FieldInfo(alias="totalRecordsCount", default=None) + """The total number of records in the data backend.""" + workspace: Optional[Workspace] = None workspace_id: Optional[str] = FieldInfo(alias="workspaceId", default=None) diff --git a/src/openlayer/types/projects/inference_pipeline_create_params.py b/src/openlayer/types/projects/inference_pipeline_create_params.py index 14e6d11b..d6879197 100644 --- a/src/openlayer/types/projects/inference_pipeline_create_params.py +++ b/src/openlayer/types/projects/inference_pipeline_create_params.py @@ -2,13 +2,29 @@ from __future__ import annotations -from typing import Optional -from typing_extensions import Literal, Required, Annotated, TypedDict +from typing import Union, Optional +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict from ..._types import SequenceNotStr from ..._utils import PropertyInfo -__all__ = ["InferencePipelineCreateParams", "Project", "Workspace"] +__all__ = [ + "InferencePipelineCreateParams", + "DataBackend", + "DataBackendUnionMember0", + "DataBackendUnionMember0Config", + "DataBackendBackendType", + "DataBackendUnionMember2", + "DataBackendUnionMember2Config", + "DataBackendUnionMember3", + "DataBackendUnionMember3Config", + "DataBackendUnionMember4", + "DataBackendUnionMember4Config", + "DataBackendUnionMember5", + "DataBackendUnionMember5Config", + "Project", + "Workspace", +] class InferencePipelineCreateParams(TypedDict, total=False): @@ -18,11 +34,218 @@ class InferencePipelineCreateParams(TypedDict, total=False): name: Required[str] """The inference pipeline name.""" + data_backend: Annotated[Optional[DataBackend], PropertyInfo(alias="dataBackend")] + project: Optional[Project] workspace: Optional[Workspace] +class DataBackendUnionMember0Config(TypedDict, total=False): + ground_truth_column_name: Annotated[Optional[str], PropertyInfo(alias="groundTruthColumnName")] + """Name of the column with the ground truths.""" + + human_feedback_column_name: Annotated[Optional[str], PropertyInfo(alias="humanFeedbackColumnName")] + """Name of the column with human feedback.""" + + inference_id_column_name: Annotated[Optional[str], PropertyInfo(alias="inferenceIdColumnName")] + """Name of the column with the inference ids. + + This is useful if you want to update rows at a later point in time. If not + provided, a unique id is generated by Openlayer. 
+ """ + + latency_column_name: Annotated[Optional[str], PropertyInfo(alias="latencyColumnName")] + """Name of the column with the latencies.""" + + timestamp_column_name: Annotated[Optional[str], PropertyInfo(alias="timestampColumnName")] + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember0(TypedDict, total=False): + backend_type: Required[Annotated[Literal["bigquery"], PropertyInfo(alias="backendType")]] + + bigquery_connection_id: Required[Annotated[Optional[str], PropertyInfo(alias="bigqueryConnectionId")]] + + config: Required[DataBackendUnionMember0Config] + + dataset_id: Required[Annotated[str, PropertyInfo(alias="datasetId")]] + + project_id: Required[Annotated[str, PropertyInfo(alias="projectId")]] + + table_id: Required[Annotated[Optional[str], PropertyInfo(alias="tableId")]] + + partition_type: Annotated[Optional[Literal["DAY", "MONTH", "YEAR"]], PropertyInfo(alias="partitionType")] + + +class DataBackendBackendType(TypedDict, total=False): + backend_type: Required[Annotated[Literal["default"], PropertyInfo(alias="backendType")]] + + +class DataBackendUnionMember2Config(TypedDict, total=False): + ground_truth_column_name: Annotated[Optional[str], PropertyInfo(alias="groundTruthColumnName")] + """Name of the column with the ground truths.""" + + human_feedback_column_name: Annotated[Optional[str], PropertyInfo(alias="humanFeedbackColumnName")] + """Name of the column with human feedback.""" + + inference_id_column_name: Annotated[Optional[str], PropertyInfo(alias="inferenceIdColumnName")] + """Name of the column with the inference ids. + + This is useful if you want to update rows at a later point in time. If not + provided, a unique id is generated by Openlayer. + """ + + latency_column_name: Annotated[Optional[str], PropertyInfo(alias="latencyColumnName")] + """Name of the column with the latencies.""" + + timestamp_column_name: Annotated[Optional[str], PropertyInfo(alias="timestampColumnName")] + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember2(TypedDict, total=False): + backend_type: Required[Annotated[Literal["snowflake"], PropertyInfo(alias="backendType")]] + + config: Required[DataBackendUnionMember2Config] + + database: Required[str] + + schema: Required[str] + + snowflake_connection_id: Required[Annotated[Optional[str], PropertyInfo(alias="snowflakeConnectionId")]] + + table: Required[Optional[str]] + + +class DataBackendUnionMember3Config(TypedDict, total=False): + ground_truth_column_name: Annotated[Optional[str], PropertyInfo(alias="groundTruthColumnName")] + """Name of the column with the ground truths.""" + + human_feedback_column_name: Annotated[Optional[str], PropertyInfo(alias="humanFeedbackColumnName")] + """Name of the column with human feedback.""" + + inference_id_column_name: Annotated[Optional[str], PropertyInfo(alias="inferenceIdColumnName")] + """Name of the column with the inference ids. + + This is useful if you want to update rows at a later point in time. If not + provided, a unique id is generated by Openlayer. + """ + + latency_column_name: Annotated[Optional[str], PropertyInfo(alias="latencyColumnName")] + """Name of the column with the latencies.""" + + timestamp_column_name: Annotated[Optional[str], PropertyInfo(alias="timestampColumnName")] + """Name of the column with the timestamps. 
+ + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember3(TypedDict, total=False): + backend_type: Required[Annotated[Literal["databricks_dtl"], PropertyInfo(alias="backendType")]] + + config: Required[DataBackendUnionMember3Config] + + databricks_dtl_connection_id: Required[Annotated[Optional[str], PropertyInfo(alias="databricksDtlConnectionId")]] + + table_id: Required[Annotated[Optional[str], PropertyInfo(alias="tableId")]] + + +class DataBackendUnionMember4Config(TypedDict, total=False): + ground_truth_column_name: Annotated[Optional[str], PropertyInfo(alias="groundTruthColumnName")] + """Name of the column with the ground truths.""" + + human_feedback_column_name: Annotated[Optional[str], PropertyInfo(alias="humanFeedbackColumnName")] + """Name of the column with human feedback.""" + + inference_id_column_name: Annotated[Optional[str], PropertyInfo(alias="inferenceIdColumnName")] + """Name of the column with the inference ids. + + This is useful if you want to update rows at a later point in time. If not + provided, a unique id is generated by Openlayer. + """ + + latency_column_name: Annotated[Optional[str], PropertyInfo(alias="latencyColumnName")] + """Name of the column with the latencies.""" + + timestamp_column_name: Annotated[Optional[str], PropertyInfo(alias="timestampColumnName")] + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember4(TypedDict, total=False): + backend_type: Required[Annotated[Literal["redshift"], PropertyInfo(alias="backendType")]] + + config: Required[DataBackendUnionMember4Config] + + redshift_connection_id: Required[Annotated[Optional[str], PropertyInfo(alias="redshiftConnectionId")]] + + schema_name: Required[Annotated[str, PropertyInfo(alias="schemaName")]] + + table_name: Required[Annotated[str, PropertyInfo(alias="tableName")]] + + +class DataBackendUnionMember5Config(TypedDict, total=False): + ground_truth_column_name: Annotated[Optional[str], PropertyInfo(alias="groundTruthColumnName")] + """Name of the column with the ground truths.""" + + human_feedback_column_name: Annotated[Optional[str], PropertyInfo(alias="humanFeedbackColumnName")] + """Name of the column with human feedback.""" + + inference_id_column_name: Annotated[Optional[str], PropertyInfo(alias="inferenceIdColumnName")] + """Name of the column with the inference ids. + + This is useful if you want to update rows at a later point in time. If not + provided, a unique id is generated by Openlayer. + """ + + latency_column_name: Annotated[Optional[str], PropertyInfo(alias="latencyColumnName")] + """Name of the column with the latencies.""" + + timestamp_column_name: Annotated[Optional[str], PropertyInfo(alias="timestampColumnName")] + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class DataBackendUnionMember5(TypedDict, total=False): + backend_type: Required[Annotated[Literal["postgres"], PropertyInfo(alias="backendType")]] + + config: Required[DataBackendUnionMember5Config] + + database: Required[str] + + postgres_connection_id: Required[Annotated[Optional[str], PropertyInfo(alias="postgresConnectionId")]] + + schema: Required[str] + + table: Required[Optional[str]] + + +DataBackend: TypeAlias = Union[ + DataBackendUnionMember0, + DataBackendBackendType, + DataBackendUnionMember2, + DataBackendUnionMember3, + DataBackendUnionMember4, + DataBackendUnionMember5, +] + + class Project(TypedDict, total=False): name: Required[str] """The project name.""" diff --git a/src/openlayer/types/projects/inference_pipeline_create_response.py b/src/openlayer/types/projects/inference_pipeline_create_response.py index a6085579..19764ce8 100644 --- a/src/openlayer/types/projects/inference_pipeline_create_response.py +++ b/src/openlayer/types/projects/inference_pipeline_create_response.py @@ -1,8 +1,8 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import List, Optional +from typing import List, Union, Optional from datetime import date, datetime -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias from pydantic import Field as FieldInfo @@ -11,6 +11,18 @@ __all__ = [ "InferencePipelineCreateResponse", "Links", + "DataBackend", + "DataBackendUnionMember0", + "DataBackendUnionMember0Config", + "DataBackendBackendType", + "DataBackendUnionMember2", + "DataBackendUnionMember2Config", + "DataBackendUnionMember3", + "DataBackendUnionMember3Config", + "DataBackendUnionMember4", + "DataBackendUnionMember4Config", + "DataBackendUnionMember5", + "DataBackendUnionMember5Config", "Project", "ProjectLinks", "ProjectGitRepo", @@ -23,6 +35,167 @@ class Links(BaseModel): app: str +class DataBackendUnionMember0Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class DataBackendUnionMember0(BaseModel): + backend_type: Literal["bigquery"] = FieldInfo(alias="backendType") + + bigquery_connection_id: Optional[str] = FieldInfo(alias="bigqueryConnectionId", default=None) + + dataset_id: str = FieldInfo(alias="datasetId") + + project_id: str = FieldInfo(alias="projectId") + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + partition_type: Optional[Literal["DAY", "MONTH", "YEAR"]] = FieldInfo(alias="partitionType", default=None) + + +class DataBackendBackendType(BaseModel): + backend_type: Literal["default"] = FieldInfo(alias="backendType") + + +class DataBackendUnionMember2Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember2(BaseModel): + backend_type: Literal["snowflake"] = FieldInfo(alias="backendType") + + database: str + + schema_: str = FieldInfo(alias="schema") + + snowflake_connection_id: Optional[str] = FieldInfo(alias="snowflakeConnectionId", default=None) + + table: Optional[str] = None + + +class DataBackendUnionMember3Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember3(BaseModel): + backend_type: Literal["databricks_dtl"] = FieldInfo(alias="backendType") + + databricks_dtl_connection_id: Optional[str] = FieldInfo(alias="databricksDtlConnectionId", default=None) + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + +class DataBackendUnionMember4Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class DataBackendUnionMember4(BaseModel): + backend_type: Literal["redshift"] = FieldInfo(alias="backendType") + + redshift_connection_id: Optional[str] = FieldInfo(alias="redshiftConnectionId", default=None) + + schema_name: str = FieldInfo(alias="schemaName") + + table_name: str = FieldInfo(alias="tableName") + + +class DataBackendUnionMember5Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class DataBackendUnionMember5(BaseModel): + backend_type: Literal["postgres"] = FieldInfo(alias="backendType") + + database: str + + postgres_connection_id: Optional[str] = FieldInfo(alias="postgresConnectionId", default=None) + + schema_: str = FieldInfo(alias="schema") + + table: Optional[str] = None + + +DataBackend: TypeAlias = Union[ + DataBackendUnionMember0, + DataBackendBackendType, + DataBackendUnionMember2, + DataBackendUnionMember3, + DataBackendUnionMember4, + DataBackendUnionMember5, + None, +] + + class ProjectLinks(BaseModel): app: str @@ -203,8 +376,16 @@ class InferencePipelineCreateResponse(BaseModel): total_goal_count: int = FieldInfo(alias="totalGoalCount") """The total number of tests.""" + data_backend: Optional[DataBackend] = FieldInfo(alias="dataBackend", default=None) + + date_last_polled: Optional[datetime] = FieldInfo(alias="dateLastPolled", default=None) + """The last time the data was polled.""" + project: Optional[Project] = None + total_records_count: Optional[int] = FieldInfo(alias="totalRecordsCount", default=None) + """The total number of records in the data backend.""" + workspace: Optional[Workspace] = None workspace_id: Optional[str] = FieldInfo(alias="workspaceId", default=None) diff --git a/src/openlayer/types/projects/inference_pipeline_list_response.py b/src/openlayer/types/projects/inference_pipeline_list_response.py index 0d5be4eb..9163afad 100644 --- a/src/openlayer/types/projects/inference_pipeline_list_response.py +++ b/src/openlayer/types/projects/inference_pipeline_list_response.py @@ -1,8 +1,8 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
-from typing import List, Optional +from typing import List, Union, Optional from datetime import date, datetime -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias from pydantic import Field as FieldInfo @@ -12,6 +12,18 @@ "InferencePipelineListResponse", "Item", "ItemLinks", + "ItemDataBackend", + "ItemDataBackendUnionMember0", + "ItemDataBackendUnionMember0Config", + "ItemDataBackendBackendType", + "ItemDataBackendUnionMember2", + "ItemDataBackendUnionMember2Config", + "ItemDataBackendUnionMember3", + "ItemDataBackendUnionMember3Config", + "ItemDataBackendUnionMember4", + "ItemDataBackendUnionMember4Config", + "ItemDataBackendUnionMember5", + "ItemDataBackendUnionMember5Config", "ItemProject", "ItemProjectLinks", "ItemProjectGitRepo", @@ -24,6 +36,167 @@ class ItemLinks(BaseModel): app: str +class ItemDataBackendUnionMember0Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class ItemDataBackendUnionMember0(BaseModel): + backend_type: Literal["bigquery"] = FieldInfo(alias="backendType") + + bigquery_connection_id: Optional[str] = FieldInfo(alias="bigqueryConnectionId", default=None) + + dataset_id: str = FieldInfo(alias="datasetId") + + project_id: str = FieldInfo(alias="projectId") + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + partition_type: Optional[Literal["DAY", "MONTH", "YEAR"]] = FieldInfo(alias="partitionType", default=None) + + +class ItemDataBackendBackendType(BaseModel): + backend_type: Literal["default"] = FieldInfo(alias="backendType") + + +class ItemDataBackendUnionMember2Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class ItemDataBackendUnionMember2(BaseModel): + backend_type: Literal["snowflake"] = FieldInfo(alias="backendType") + + database: str + + schema_: str = FieldInfo(alias="schema") + + snowflake_connection_id: Optional[str] = FieldInfo(alias="snowflakeConnectionId", default=None) + + table: Optional[str] = None + + +class ItemDataBackendUnionMember3Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class ItemDataBackendUnionMember3(BaseModel): + backend_type: Literal["databricks_dtl"] = FieldInfo(alias="backendType") + + databricks_dtl_connection_id: Optional[str] = FieldInfo(alias="databricksDtlConnectionId", default=None) + + table_id: Optional[str] = FieldInfo(alias="tableId", default=None) + + +class ItemDataBackendUnionMember4Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. + """ + + +class ItemDataBackendUnionMember4(BaseModel): + backend_type: Literal["redshift"] = FieldInfo(alias="backendType") + + redshift_connection_id: Optional[str] = FieldInfo(alias="redshiftConnectionId", default=None) + + schema_name: str = FieldInfo(alias="schemaName") + + table_name: str = FieldInfo(alias="tableName") + + +class ItemDataBackendUnionMember5Config(BaseModel): + ground_truth_column_name: Optional[str] = FieldInfo(alias="groundTruthColumnName", default=None) + """Name of the column with the ground truths.""" + + human_feedback_column_name: Optional[str] = FieldInfo(alias="humanFeedbackColumnName", default=None) + """Name of the column with human feedback.""" + + latency_column_name: Optional[str] = FieldInfo(alias="latencyColumnName", default=None) + """Name of the column with the latencies.""" + + timestamp_column_name: Optional[str] = FieldInfo(alias="timestampColumnName", default=None) + """Name of the column with the timestamps. + + Timestamps must be in UNIX sec format. If not provided, the upload timestamp is + used. 
+ """ + + +class ItemDataBackendUnionMember5(BaseModel): + backend_type: Literal["postgres"] = FieldInfo(alias="backendType") + + database: str + + postgres_connection_id: Optional[str] = FieldInfo(alias="postgresConnectionId", default=None) + + schema_: str = FieldInfo(alias="schema") + + table: Optional[str] = None + + +ItemDataBackend: TypeAlias = Union[ + ItemDataBackendUnionMember0, + ItemDataBackendBackendType, + ItemDataBackendUnionMember2, + ItemDataBackendUnionMember3, + ItemDataBackendUnionMember4, + ItemDataBackendUnionMember5, + None, +] + + class ItemProjectLinks(BaseModel): app: str @@ -204,8 +377,16 @@ class Item(BaseModel): total_goal_count: int = FieldInfo(alias="totalGoalCount") """The total number of tests.""" + data_backend: Optional[ItemDataBackend] = FieldInfo(alias="dataBackend", default=None) + + date_last_polled: Optional[datetime] = FieldInfo(alias="dateLastPolled", default=None) + """The last time the data was polled.""" + project: Optional[ItemProject] = None + total_records_count: Optional[int] = FieldInfo(alias="totalRecordsCount", default=None) + """The total number of records in the data backend.""" + workspace: Optional[ItemWorkspace] = None workspace_id: Optional[str] = FieldInfo(alias="workspaceId", default=None) diff --git a/src/openlayer/types/projects/test_create_params.py b/src/openlayer/types/projects/test_create_params.py index b2ebdba0..98f615ae 100644 --- a/src/openlayer/types/projects/test_create_params.py +++ b/src/openlayer/types/projects/test_create_params.py @@ -73,6 +73,12 @@ class TestCreateParams(TypedDict, total=False): archived: bool """Whether the test is archived.""" + default_to_all_pipelines: Annotated[Optional[bool], PropertyInfo(alias="defaultToAllPipelines")] + """ + Whether to apply the test to all pipelines (data sources) or to a specific set + of pipelines. Only applies to tests that use production data. + """ + delay_window: Annotated[Optional[float], PropertyInfo(alias="delayWindow")] """The delay window in seconds. Only applies to tests that use production data.""" @@ -82,6 +88,24 @@ class TestCreateParams(TypedDict, total=False): Only applies to tests that use production data. """ + exclude_pipelines: Annotated[Optional[SequenceNotStr[str]], PropertyInfo(alias="excludePipelines")] + """Array of pipelines (data sources) to which the test should not be applied. + + Only applies to tests that use production data. + """ + + include_historical_data: Annotated[Optional[bool], PropertyInfo(alias="includeHistoricalData")] + """Whether to include historical data in the test result. + + Only applies to tests that use production data. + """ + + include_pipelines: Annotated[Optional[SequenceNotStr[str]], PropertyInfo(alias="includePipelines")] + """Array of pipelines (data sources) to which the test should be applied. + + Only applies to tests that use production data. 
+ """ + uses_ml_model: Annotated[bool, PropertyInfo(alias="usesMlModel")] """Whether the test uses an ML model.""" diff --git a/src/openlayer/types/projects/test_create_response.py b/src/openlayer/types/projects/test_create_response.py index 91d6d6de..83d4e4c4 100644 --- a/src/openlayer/types/projects/test_create_response.py +++ b/src/openlayer/types/projects/test_create_response.py @@ -168,6 +168,12 @@ class TestCreateResponse(BaseModel): archived: Optional[bool] = None """Whether the test is archived.""" + default_to_all_pipelines: Optional[bool] = FieldInfo(alias="defaultToAllPipelines", default=None) + """ + Whether to apply the test to all pipelines (data sources) or to a specific set + of pipelines. Only applies to tests that use production data. + """ + delay_window: Optional[float] = FieldInfo(alias="delayWindow", default=None) """The delay window in seconds. Only applies to tests that use production data.""" @@ -177,6 +183,24 @@ class TestCreateResponse(BaseModel): Only applies to tests that use production data. """ + exclude_pipelines: Optional[List[str]] = FieldInfo(alias="excludePipelines", default=None) + """Array of pipelines (data sources) to which the test should not be applied. + + Only applies to tests that use production data. + """ + + include_historical_data: Optional[bool] = FieldInfo(alias="includeHistoricalData", default=None) + """Whether to include historical data in the test result. + + Only applies to tests that use production data. + """ + + include_pipelines: Optional[List[str]] = FieldInfo(alias="includePipelines", default=None) + """Array of pipelines (data sources) to which the test should be applied. + + Only applies to tests that use production data. + """ + uses_ml_model: Optional[bool] = FieldInfo(alias="usesMlModel", default=None) """Whether the test uses an ML model.""" diff --git a/src/openlayer/types/projects/test_list_response.py b/src/openlayer/types/projects/test_list_response.py index c8afd5f5..cc7343c3 100644 --- a/src/openlayer/types/projects/test_list_response.py +++ b/src/openlayer/types/projects/test_list_response.py @@ -169,6 +169,12 @@ class Item(BaseModel): archived: Optional[bool] = None """Whether the test is archived.""" + default_to_all_pipelines: Optional[bool] = FieldInfo(alias="defaultToAllPipelines", default=None) + """ + Whether to apply the test to all pipelines (data sources) or to a specific set + of pipelines. Only applies to tests that use production data. + """ + delay_window: Optional[float] = FieldInfo(alias="delayWindow", default=None) """The delay window in seconds. Only applies to tests that use production data.""" @@ -178,6 +184,24 @@ class Item(BaseModel): Only applies to tests that use production data. """ + exclude_pipelines: Optional[List[str]] = FieldInfo(alias="excludePipelines", default=None) + """Array of pipelines (data sources) to which the test should not be applied. + + Only applies to tests that use production data. + """ + + include_historical_data: Optional[bool] = FieldInfo(alias="includeHistoricalData", default=None) + """Whether to include historical data in the test result. + + Only applies to tests that use production data. + """ + + include_pipelines: Optional[List[str]] = FieldInfo(alias="includePipelines", default=None) + """Array of pipelines (data sources) to which the test should be applied. + + Only applies to tests that use production data. 
+ """ + uses_ml_model: Optional[bool] = FieldInfo(alias="usesMlModel", default=None) """Whether the test uses an ML model.""" diff --git a/tests/api_resources/projects/test_inference_pipelines.py b/tests/api_resources/projects/test_inference_pipelines.py index e92bf727..eb112725 100644 --- a/tests/api_resources/projects/test_inference_pipelines.py +++ b/tests/api_resources/projects/test_inference_pipelines.py @@ -35,6 +35,21 @@ def test_method_create_with_all_params(self, client: Openlayer) -> None: project_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", description="This pipeline is used for production.", name="production", + data_backend={ + "backend_type": "bigquery", + "bigquery_connection_id": "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + "config": { + "ground_truth_column_name": "ground_truth", + "human_feedback_column_name": "human_feedback", + "inference_id_column_name": "id", + "latency_column_name": "latency", + "timestamp_column_name": "timestamp", + }, + "dataset_id": "my-dataset", + "project_id": "my-project", + "table_id": "my-table", + "partition_type": "DAY", + }, project={ "name": "My Project", "task_type": "llm-base", @@ -156,6 +171,21 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenlayer) project_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", description="This pipeline is used for production.", name="production", + data_backend={ + "backend_type": "bigquery", + "bigquery_connection_id": "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + "config": { + "ground_truth_column_name": "ground_truth", + "human_feedback_column_name": "human_feedback", + "inference_id_column_name": "id", + "latency_column_name": "latency", + "timestamp_column_name": "timestamp", + }, + "dataset_id": "my-dataset", + "project_id": "my-project", + "table_id": "my-table", + "partition_type": "DAY", + }, project={ "name": "My Project", "task_type": "llm-base", diff --git a/tests/api_resources/projects/test_tests.py b/tests/api_resources/projects/test_tests.py index a37a33ba..7fae681f 100644 --- a/tests/api_resources/projects/test_tests.py +++ b/tests/api_resources/projects/test_tests.py @@ -57,8 +57,12 @@ def test_method_create_with_all_params(self, client: Openlayer) -> None: ], type="integrity", archived=False, + default_to_all_pipelines=True, delay_window=0, evaluation_window=3600, + exclude_pipelines=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + include_historical_data=True, + include_pipelines=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], uses_ml_model=False, uses_production_data=False, uses_reference_dataset=False, @@ -249,8 +253,12 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenlayer) ], type="integrity", archived=False, + default_to_all_pipelines=True, delay_window=0, evaluation_window=3600, + exclude_pipelines=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + include_historical_data=True, + include_pipelines=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], uses_ml_model=False, uses_production_data=False, uses_reference_dataset=False, diff --git a/tests/test_client.py b/tests/test_client.py index 2cb1bfd6..0ea3a4e5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -59,51 +59,49 @@ def _get_open_connections(client: Openlayer | AsyncOpenlayer) -> int: class TestOpenlayer: - client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - def test_raw_response(self, respx_mock: MockRouter) -> None: + def test_raw_response(self, respx_mock: MockRouter, client: Openlayer) -> None: 
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Openlayer) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, client: Openlayer) -> None: + copied = client.copy() + assert id(copied) != id(client) - copied = self.client.copy(api_key="another My API Key") + copied = client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, client: Openlayer) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: client = Openlayer( @@ -138,6 +136,7 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() def test_copy_default_query(self) -> None: client = Openlayer( @@ -175,13 +174,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + client.close() + + def test_copy_signature(self, client: Openlayer) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. 
- self.client.__init__, # type: ignore[misc] + client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -192,12 +193,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, client: Openlayer) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -254,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + def test_request_timeout(self, client: Openlayer) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(100.0) @@ -274,6 +273,8 @@ def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + client.close() + def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used with httpx.Client(timeout=None) as http_client: @@ -285,6 +286,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + client.close() + # no timeout given to the httpx client should not use the httpx default with httpx.Client() as http_client: client = Openlayer( @@ -295,6 +298,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + client.close() + # explicitly passing the default timeout currently results in it being ignored with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = Openlayer( @@ -305,6 +310,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + client.close() + async def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): async with httpx.AsyncClient() as http_client: @@ -316,14 +323,14 @@ async def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = Openlayer( + test_client = Openlayer( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - 
request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = Openlayer( + test_client2 = Openlayer( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -332,10 +339,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + test_client.close() + test_client2.close() + def test_validate_headers(self) -> None: client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -373,8 +383,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + client.close() + + def test_request_extra_json(self, client: Openlayer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -385,7 +397,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -396,7 +408,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -407,8 +419,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Openlayer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -418,7 +430,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -429,8 +441,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Openlayer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -443,7 +455,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -457,7 +469,7 @@ def test_request_extra_query(self) -> None: assert params == 
{"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -500,7 +512,7 @@ def test_multipart_repeating_array(self, client: Openlayer) -> None: ] @pytest.mark.respx(base_url=base_url) - def test_basic_union_response(self, respx_mock: MockRouter) -> None: + def test_basic_union_response(self, respx_mock: MockRouter, client: Openlayer) -> None: class Model1(BaseModel): name: str @@ -509,12 +521,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + def test_union_response_different_types(self, respx_mock: MockRouter, client: Openlayer) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -525,18 +537,18 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Openlayer) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -552,7 +564,7 @@ class Model(BaseModel): ) ) - response = self.client.get("/foo", cast_to=Model) + response = client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 @@ -564,6 +576,8 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" + client.close() + def test_base_url_env(self) -> None: with update_env(OPENLAYER_BASE_URL="http://localhost:5000/from/env"): client = Openlayer(api_key=api_key, _strict_response_validation=True) @@ -591,6 +605,7 @@ def test_base_url_trailing_slash(self, client: Openlayer) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -614,6 +629,7 @@ def test_base_url_no_trailing_slash(self, client: Openlayer) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -637,35 +653,36 @@ def test_absolute_request_url(self, client: Openlayer) -> None: ), ) assert request.url == "https://myapi.com/foo" + client.close() def test_copied_client_does_not_close_http(self) -> None: - client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = Openlayer(base_url=base_url, api_key=api_key, 
_strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied - assert not client.is_closed() + assert not test_client.is_closed() def test_client_context_manager(self) -> None: - client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - with client as c2: - assert c2 is client + test_client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + def test_client_response_validation_error(self, respx_mock: MockRouter, client: Openlayer) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - self.client.get("/foo", cast_to=Model) + client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -685,11 +702,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): strict_client.get("/foo", cast_to=Model) - client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = client.get("/foo", cast_to=Model) + response = non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + strict_client.close() + non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -712,9 +732,9 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: Openlayer + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) calculated = client._calculate_retry_timeout(remaining_retries, options, headers) @@ -742,7 +762,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien ], ).__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -765,7 +785,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client } ], ).__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -911,83 +931,77 @@ def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - def test_follow_redirects(self, respx_mock: MockRouter) -> None: + def test_follow_redirects(self, respx_mock: MockRouter, client: Openlayer) -> None: # 
Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Openlayer) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - self.client.post( - "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response - ) + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" class TestAsyncOpenlayer: - client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response(self, respx_mock: MockRouter) -> None: + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncOpenlayer) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncOpenlayer) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, async_client: AsyncOpenlayer) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) - copied = self.client.copy(api_key="another My API Key") + copied = async_client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert async_client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, async_client: AsyncOpenlayer) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = async_client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert async_client.max_retries == 2 copied2 = 
copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(async_client.timeout, httpx.Timeout) - def test_copy_default_headers(self) -> None: + async def test_copy_default_headers(self) -> None: client = AsyncOpenlayer( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) @@ -1020,8 +1034,9 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() - def test_copy_default_query(self) -> None: + async def test_copy_default_query(self) -> None: client = AsyncOpenlayer( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} ) @@ -1057,13 +1072,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + await client.close() + + def test_copy_signature(self, async_client: AsyncOpenlayer) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + async_client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(async_client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -1074,12 +1091,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, async_client: AsyncOpenlayer) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = async_client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. 
build_request(options) @@ -1136,12 +1153,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - async def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + async def test_request_timeout(self, async_client: AsyncOpenlayer) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( + request = async_client._build_request( FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) ) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1156,6 +1173,8 @@ async def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + await client.close() + async def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used async with httpx.AsyncClient(timeout=None) as http_client: @@ -1167,6 +1186,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + await client.close() + # no timeout given to the httpx client should not use the httpx default async with httpx.AsyncClient() as http_client: client = AsyncOpenlayer( @@ -1177,6 +1198,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + await client.close() + # explicitly passing the default timeout currently results in it being ignored async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = AsyncOpenlayer( @@ -1187,6 +1210,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + await client.close() + def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): with httpx.Client() as http_client: @@ -1197,15 +1222,15 @@ def test_invalid_http_client(self) -> None: http_client=cast(Any, http_client), ) - def test_default_headers_option(self) -> None: - client = AsyncOpenlayer( + async def test_default_headers_option(self) -> None: + test_client = AsyncOpenlayer( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = AsyncOpenlayer( + test_client2 = AsyncOpenlayer( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -1214,10 +1239,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + await test_client.close() + await test_client2.close() + def 
test_validate_headers(self) -> None: client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -1237,7 +1265,7 @@ def test_validate_headers(self) -> None: ) assert request2.headers.get("Authorization") is None - def test_default_query_option(self) -> None: + async def test_default_query_option(self) -> None: client = AsyncOpenlayer( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} ) @@ -1255,8 +1283,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + await client.close() + + def test_request_extra_json(self, client: Openlayer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1267,7 +1297,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1278,7 +1308,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1289,8 +1319,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Openlayer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1300,7 +1330,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1311,8 +1341,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Openlayer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1325,7 +1355,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1339,7 +1369,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1382,7 +1412,7 @@ def test_multipart_repeating_array(self, async_client: AsyncOpenlayer) -> None: ] @pytest.mark.respx(base_url=base_url) - async def test_basic_union_response(self, respx_mock: MockRouter) -> None: + async def test_basic_union_response(self, 
respx_mock: MockRouter, async_client: AsyncOpenlayer) -> None: class Model1(BaseModel): name: str @@ -1391,12 +1421,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - async def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncOpenlayer) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -1407,18 +1437,20 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncOpenlayer + ) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -1434,11 +1466,11 @@ class Model(BaseModel): ) ) - response = await self.client.get("/foo", cast_to=Model) + response = await async_client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 - def test_base_url_setter(self) -> None: + async def test_base_url_setter(self) -> None: client = AsyncOpenlayer( base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True ) @@ -1448,7 +1480,9 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" - def test_base_url_env(self) -> None: + await client.close() + + async def test_base_url_env(self) -> None: with update_env(OPENLAYER_BASE_URL="http://localhost:5000/from/env"): client = AsyncOpenlayer(api_key=api_key, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1468,7 +1502,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: AsyncOpenlayer) -> None: + async def test_base_url_trailing_slash(self, client: AsyncOpenlayer) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1477,6 +1511,7 @@ def test_base_url_trailing_slash(self, client: AsyncOpenlayer) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1493,7 +1528,7 @@ def test_base_url_trailing_slash(self, client: AsyncOpenlayer) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: AsyncOpenlayer) -> None: + async def 
test_base_url_no_trailing_slash(self, client: AsyncOpenlayer) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1502,6 +1537,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncOpenlayer) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1518,7 +1554,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncOpenlayer) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: AsyncOpenlayer) -> None: + async def test_absolute_request_url(self, client: AsyncOpenlayer) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1527,37 +1563,37 @@ def test_absolute_request_url(self, client: AsyncOpenlayer) -> None: ), ) assert request.url == "https://myapi.com/foo" + await client.close() async def test_copied_client_does_not_close_http(self) -> None: - client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied await asyncio.sleep(0.2) - assert not client.is_closed() + assert not test_client.is_closed() async def test_client_context_manager(self) -> None: - client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - async with client as c2: - assert c2 is client + test_client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncOpenlayer) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - await self.client.get("/foo", cast_to=Model) + await async_client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -1568,7 +1604,6 @@ async def test_client_max_retries_validation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str @@ -1580,11 +1615,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): await strict_client.get("/foo", cast_to=Model) - client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = await client.get("/foo", cast_to=Model) + response = await non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + await strict_client.close() + await non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -1607,13 +1645,12 @@ class Model(BaseModel): ], ) 
@mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - @pytest.mark.asyncio - async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncOpenlayer + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -1640,7 +1677,7 @@ async def test_retrying_timeout_errors_doesnt_leak( ], ).__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -1665,12 +1702,11 @@ async def test_retrying_status_errors_doesnt_leak( } ], ).__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( self, @@ -1716,7 +1752,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_omit_retry_count_header( self, async_client: AsyncOpenlayer, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1755,7 +1790,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_overwrite_retry_count_header( self, async_client: AsyncOpenlayer, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1818,26 +1852,26 @@ async def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncOpenlayer) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} 
@pytest.mark.respx(base_url=base_url) - async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncOpenlayer) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - await self.client.post( + await async_client.post( "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response ) diff --git a/tests/test_models.py b/tests/test_models.py index ab95d39c..197e1412 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ from openlayer._utils import PropertyInfo from openlayer._compat import PYDANTIC_V1, parse_obj, model_dump, model_json -from openlayer._models import BaseModel, construct_type +from openlayer._models import DISCRIMINATOR_CACHE, BaseModel, construct_type class BasicModel(BaseModel): @@ -809,7 +809,7 @@ class B(BaseModel): UnionType = cast(Any, Union[A, B]) - assert not hasattr(UnionType, "__discriminator__") + assert not DISCRIMINATOR_CACHE.get(UnionType) m = construct_type( value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) @@ -818,7 +818,7 @@ class B(BaseModel): assert m.type == "b" assert m.data == "foo" # type: ignore[comparison-overlap] - discriminator = UnionType.__discriminator__ + discriminator = DISCRIMINATOR_CACHE.get(UnionType) assert discriminator is not None m = construct_type( @@ -830,7 +830,7 @@ class B(BaseModel): # if the discriminator details object stays the same between invocations then # we hit the cache - assert UnionType.__discriminator__ is discriminator + assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator @pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") diff --git a/tests/test_transform.py b/tests/test_transform.py index c8f3477f..f51f6e4a 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -8,7 +8,7 @@ import pytest -from openlayer._types import NOT_GIVEN, Base64FileInput +from openlayer._types import Base64FileInput, omit, not_given from openlayer._utils import ( PropertyInfo, transform as _transform, @@ -450,4 +450,11 @@ async def test_transform_skipping(use_async: bool) -> None: @pytest.mark.asyncio async def test_strips_notgiven(use_async: bool) -> None: assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} - assert await transform({"foo_bar": NOT_GIVEN}, Foo1, use_async) == {} + assert await transform({"foo_bar": not_given}, Foo1, use_async) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_strips_omit(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": omit}, Foo1, use_async) == {}
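The tests/test_client.py hunks above drop the class-level `client = Openlayer(...)` / `client = AsyncOpenlayer(...)` attributes in favour of `client` and `async_client` pytest fixtures, and add explicit `close()` calls for clients that tests construct inline. The fixtures themselves live in the suite's conftest, which is not part of this diff; the following is only a sketch of what they are assumed to look like, with the fixture names, base URL, and pytest-asyncio auto mode inferred from how the tests use them rather than taken from the repo.

# Sketch only: assumed contents of tests/conftest.py (not shown in this diff).
from __future__ import annotations

import os
from typing import Iterator, AsyncIterator

import pytest

from openlayer import Openlayer, AsyncOpenlayer

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")  # assumed default
api_key = "My API Key"  # matches the value the tests assert against


@pytest.fixture
def client() -> Iterator[Openlayer]:
    # Shared sync client; closed automatically when the fixture is torn down.
    with Openlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) as c:
        yield c


@pytest.fixture
async def async_client() -> AsyncIterator[AsyncOpenlayer]:
    # Shared async client; assumes pytest-asyncio in auto mode so the async fixture runs.
    async with AsyncOpenlayer(base_url=base_url, api_key=api_key, _strict_response_validation=True) as c:
        yield c

With the fixtures owning the client lifecycle, only clients a test builds itself need closing, which is exactly what the added `client.close()` / `await client.close()` lines do.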
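The tests/test_transform.py hunk swaps the `NOT_GIVEN` import for the lower-cased `not_given` and `omit` sentinels and adds a parallel `test_strips_omit`. A minimal sketch of the behaviour those tests pin down, assuming the `Foo1` TypedDict defined elsewhere in that module (a `foo_bar` key serialized as `fooBar`) and the synchronous `transform` helper:

# Hedged sketch; Foo1 here is an assumption mirroring tests/test_transform.py.
from typing import Annotated
from typing_extensions import TypedDict

from openlayer._types import omit, not_given
from openlayer._utils import PropertyInfo, transform


class Foo1(TypedDict, total=False):
    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]


# Real values are aliased, while both sentinels are stripped before serialization,
# so neither key reaches the request body.
assert transform({"foo_bar": "bar"}, Foo1) == {"fooBar": "bar"}
assert transform({"foo_bar": omit}, Foo1) == {}
assert transform({"foo_bar": not_given}, Foo1) == {}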