From f3c1553ff7d9cc5f3c7a5f138667d5ea83b48595 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Thu, 7 Aug 2025 22:03:22 -0400 Subject: [PATCH 01/17] Remove duplicate exceptions.py from root directory (#94) The exceptions.py file existed in both the root directory and langextract/ directory with identical content. This removes the duplicate from the root to avoid confusion and maintain proper package structure. --- exceptions.py | 30 ------------------------------ 1 file changed, 30 deletions(-) delete mode 100644 exceptions.py diff --git a/exceptions.py b/exceptions.py deleted file mode 100644 index 0199da56..00000000 --- a/exceptions.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base exceptions for LangExtract. - -This module defines the base exception class that all LangExtract exceptions -inherit from. Individual modules define their own specific exceptions. -""" - -__all__ = ["LangExtractError"] - - -class LangExtractError(Exception): - """Base exception for all LangExtract errors. - - All exceptions raised by LangExtract should inherit from this class. - This allows users to catch all LangExtract-specific errors with a single - except clause. - """ From 845258cd63ea3e4e360c7a9e89761ebd5413594a Mon Sep 17 00:00:00 2001 From: Wade <59910348+wade6716@users.noreply.github.com> Date: Fri, 8 Aug 2025 19:33:36 +0800 Subject: [PATCH 02/17] Fix unicode escaping in example generation (#98) --- langextract/prompting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langextract/prompting.py b/langextract/prompting.py index 4484273b..1a50e0c2 100644 --- a/langextract/prompting.py +++ b/langextract/prompting.py @@ -127,7 +127,7 @@ def format_example_as_text(self, example: data.ExampleData) -> str: else: answer = formatted_content.strip() elif self.format_type == data.FormatType.JSON: - formatted_content = json.dumps(data_dict, indent=2) + formatted_content = json.dumps(data_dict, indent=2, ensure_ascii=False) if self.fence_output: answer = f"```json\n{formatted_content.strip()}\n```" else: From 00acc436ed51e402f1be596e7618210cffb20edc Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Fri, 8 Aug 2025 07:44:22 -0400 Subject: [PATCH 03/17] Add provider registry infrastructure and custom provider plugin example (#97) Introduces a provider registry system enabling third-party providers to be dynamically registered and discovered through a plugin architecture. Users can now integrate custom LLM backends (Azure OpenAI, AWS Bedrock, custom inference servers) without modifying core LangExtract code. 
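For illustration, a minimal sketch of the intended plugin flow using the APIs added in this patch (registry decorator, `ModelConfig`, `create_model`); the provider class, pattern, and model ID below are placeholders, not names shipped by this change:

```python
import langextract as lx

# Hypothetical third-party provider; pattern and class name are examples only.
@lx.providers.registry.register(r'^mymodel')
class MyProviderLanguageModel(lx.inference.BaseLanguageModel):

  def __init__(self, model_id: str, **kwargs):
    super().__init__()
    self.model_id = model_id

  def infer(self, batch_prompts, **kwargs):
    for prompt in batch_prompts:
      # A real provider would call its backend here; echoing keeps the sketch runnable.
      yield [lx.inference.ScoredOutput(score=1.0, output=prompt)]


# Explicit selection through the new factory when patterns overlap:
config = lx.factory.ModelConfig(
    model_id="mymodel-latest",
    provider="MyProviderLanguageModel",
)
model = lx.factory.create_model(config)
```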
Fixes #80, #67, #54, #49, #48, #53 Key Changes: **Provider Registry** (`langextract/providers/registry.py`) - Pattern-based registration with priority resolution - Automatic discovery via Python entry points - Lazy loading for performance **Factory Enhancements** (`langextract/factory.py`) - `ModelConfig` dataclass for structured configuration - Explicit provider selection when patterns overlap - Full backward compatibility maintained **Plugin Example** (`examples/custom_provider_plugin/`) - Complete working example with entry point configuration - Shows how to create custom providers for any backend **Documentation** - Comprehensive provider system README with architecture diagrams - Step-by-step plugin creation guide **Dependencies** - Move openai to optional dependencies - Update tox.ini to include openai in test environments **Lint Fixes** - Add appropriate pylint suppressions for legitimate patterns - Fix unused variable warnings in tests - Address import and global statement warnings No anticipated breakage - full backward compatibility maintained. Given significant internal changes to provider loading, issues should be reported if unexpected behavior is encountered. --- CONTRIBUTING.md | 9 +- README.md | 17 +- examples/custom_provider_plugin/README.md | 88 +++ .../langextract_provider_example/__init__.py | 20 + .../langextract_provider_example/provider.py | 125 ++++ .../custom_provider_plugin/pyproject.toml | 38 ++ .../test_example_provider.py | 57 ++ langextract/__init__.py | 38 +- langextract/exceptions.py | 47 +- langextract/factory.py | 153 +++++ langextract/inference.py | 598 +++++------------- langextract/providers/README.md | 372 +++++++++++ langextract/providers/__init__.py | 85 +++ langextract/providers/gemini.py | 185 ++++++ langextract/providers/ollama.py | 271 ++++++++ langextract/providers/openai.py | 201 ++++++ langextract/providers/registry.py | 213 +++++++ pyproject.toml | 6 +- tests/factory_test.py | 315 +++++++++ tests/inference_test.py | 3 +- tests/init_test.py | 23 +- tests/registry_test.py | 197 ++++++ tests/test_live_api.py | 69 +- tests/test_ollama_integration.py | 1 - tox.ini | 8 +- 25 files changed, 2663 insertions(+), 476 deletions(-) create mode 100644 examples/custom_provider_plugin/README.md create mode 100644 examples/custom_provider_plugin/langextract_provider_example/__init__.py create mode 100644 examples/custom_provider_plugin/langextract_provider_example/provider.py create mode 100644 examples/custom_provider_plugin/pyproject.toml create mode 100644 examples/custom_provider_plugin/test_example_provider.py create mode 100644 langextract/factory.py create mode 100644 langextract/providers/README.md create mode 100644 langextract/providers/__init__.py create mode 100644 langextract/providers/gemini.py create mode 100644 langextract/providers/ollama.py create mode 100644 langextract/providers/openai.py create mode 100644 langextract/providers/registry.py create mode 100644 tests/factory_test.py create mode 100644 tests/registry_test.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index aa5038d4..78672b47 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -105,7 +105,14 @@ For full testing across Python versions: tox # runs pylint + pytest on Python 3.10 and 3.11 ``` -### 5. Submit Your Pull Request +### 5. Adding Custom Model Providers + +If you want to add support for a new LLM provider, please refer to the [Provider System Documentation](langextract/providers/README.md). 
The recommended approach is to create an external plugin package rather than modifying the core library. This allows for: +- Independent versioning and releases +- Faster iteration without core review cycles +- Custom dependencies without affecting core users + +### 6. Submit Your Pull Request All submissions, including submissions by project members, require review. We use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) diff --git a/README.md b/README.md index 2cbe5820..b373ee00 100644 --- a/README.md +++ b/README.md @@ -255,7 +255,7 @@ result = lx.extract( ## Using OpenAI Models -LangExtract also supports OpenAI models. Example OpenAI configuration: +LangExtract supports OpenAI models (requires optional dependency: `pip install langextract[openai]`): ```python import langextract as lx @@ -264,8 +264,7 @@ result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, - language_model_type=lx.inference.OpenAILanguageModel, - model_id="gpt-4o", + model_id="gpt-4o", # Automatically selects OpenAI provider api_key=os.environ.get('OPENAI_API_KEY'), fence_output=True, use_schema_constraints=False @@ -285,8 +284,7 @@ result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, - language_model_type=lx.inference.OllamaLanguageModel, - model_id="gemma2:2b", # or any Ollama model + model_id="gemma2:2b", # Automatically selects Ollama provider model_url="http://localhost:11434", fence_output=False, use_schema_constraints=False @@ -328,6 +326,15 @@ with development, testing, and pull requests. You must sign a [Contributor License Agreement](https://cla.developers.google.com/about) before submitting patches. +### Adding Custom Model Providers + +LangExtract supports custom LLM providers through a plugin system. You can add support for new models by creating an external Python package that registers with LangExtract's provider registry. This allows you to: +- Add new model support without modifying the core library +- Distribute your provider independently +- Maintain custom dependencies + +For detailed instructions, see the [Provider System Documentation](langextract/providers/README.md). + ## Testing To run tests locally from the source: diff --git a/examples/custom_provider_plugin/README.md b/examples/custom_provider_plugin/README.md new file mode 100644 index 00000000..6aaf3795 --- /dev/null +++ b/examples/custom_provider_plugin/README.md @@ -0,0 +1,88 @@ +# Custom Provider Plugin Example + +This example demonstrates how to create a custom provider plugin that extends LangExtract with your own model backend. + +**Note**: This is an example included in the LangExtract repository for reference. It is not part of the LangExtract package and won't be installed when you `pip install langextract`. 
+ +## Structure + +``` +custom_provider_plugin/ +├── pyproject.toml # Package configuration and metadata +├── README.md # This file +├── langextract_provider_example/ # Package directory +│ ├── __init__.py # Package initialization +│ └── provider.py # Custom provider implementation +└── test_example_provider.py # Test script +``` + +## Key Components + +### Provider Implementation (`provider.py`) + +```python +@lx.providers.registry.register( + r'^gemini', # Pattern for model IDs this provider handles +) +class CustomGeminiProvider(lx.inference.BaseLanguageModel): + def __init__(self, model_id: str, **kwargs): + # Initialize your backend client + + def infer(self, batch_prompts, **kwargs): + # Call your backend API and return results +``` + +### Package Configuration (`pyproject.toml`) + +```toml +[project.entry-points."langextract.providers"] +custom_gemini = "langextract_provider_example:CustomGeminiProvider" +``` + +This entry point allows LangExtract to automatically discover your provider. + +## Installation + +```bash +# Navigate to this example directory first +cd examples/custom_provider_plugin + +# Install in development mode +pip install -e . + +# Test the provider (must be run from this directory) +python test_example_provider.py +``` + +## Usage + +Since this example registers the same pattern as the default Gemini provider, you must explicitly specify it: + +```python +import langextract as lx + +config = lx.factory.ModelConfig( + model_id="gemini-2.5-flash", + provider="CustomGeminiProvider", + provider_kwargs={"api_key": "your-api-key"} +) +model = lx.factory.create_model(config) + +result = lx.extract( + text_or_documents="Your text here", + model=model, + prompt_description="Extract key information" +) +``` + +## Creating Your Own Provider + +1. Copy this example as a starting point +2. Update the provider class name and registration pattern +3. Replace the Gemini implementation with your own backend +4. Update package name in `pyproject.toml` +5. Install and test your plugin + +## License + +Apache License 2.0 diff --git a/examples/custom_provider_plugin/langextract_provider_example/__init__.py b/examples/custom_provider_plugin/langextract_provider_example/__init__.py new file mode 100644 index 00000000..abd57fce --- /dev/null +++ b/examples/custom_provider_plugin/langextract_provider_example/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example custom provider plugin for LangExtract.""" + +from langextract_provider_example.provider import CustomGeminiProvider + +__all__ = ["CustomGeminiProvider"] +__version__ = "0.1.0" diff --git a/examples/custom_provider_plugin/langextract_provider_example/provider.py b/examples/custom_provider_plugin/langextract_provider_example/provider.py new file mode 100644 index 00000000..fb7317a8 --- /dev/null +++ b/examples/custom_provider_plugin/langextract_provider_example/provider.py @@ -0,0 +1,125 @@ +# Copyright 2025 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal example of a custom provider plugin for LangExtract.""" + +from __future__ import annotations + +import dataclasses +from typing import Any, Iterator, Sequence + +import langextract as lx + + +@lx.providers.registry.register( + r'^gemini', # Matches Gemini model IDs (same as default provider) +) +@dataclasses.dataclass(init=False) +class CustomGeminiProvider(lx.inference.BaseLanguageModel): + """Example custom LangExtract provider implementation. + + This demonstrates how to create a custom provider for LangExtract + that can intercept and handle model requests. This example uses + Gemini as the backend, but you would replace this with your own + API or model implementation. + + Note: Since this registers the same pattern as the default Gemini provider, + you must explicitly specify this provider when creating a model: + + config = lx.factory.ModelConfig( + model_id="gemini-2.5-flash", + provider="CustomGeminiProvider" + ) + model = lx.factory.create_model(config) + """ + + model_id: str + api_key: str | None + temperature: float + _client: Any = dataclasses.field(repr=False, compare=False) + + def __init__( + self, + model_id: str = 'gemini-2.5-flash', + api_key: str | None = None, + temperature: float = 0.0, + **kwargs: Any, + ) -> None: + """Initialize the custom provider. + + Args: + model_id: The model ID. + api_key: API key for the service. + temperature: Sampling temperature. + **kwargs: Additional parameters. + """ + # TODO: Replace with your own client initialization + try: + from google import genai # pylint: disable=import-outside-toplevel + except ImportError as e: + raise lx.exceptions.InferenceConfigError( + 'This example requires google-genai package. ' + 'Install with: pip install google-genai' + ) from e + + self.model_id = model_id + self.api_key = api_key + self.temperature = temperature + + # Store any additional kwargs for potential use + self._extra_kwargs = kwargs + + if not self.api_key: + raise lx.exceptions.InferenceConfigError( + 'API key required. Set GEMINI_API_KEY or pass api_key parameter.' + ) + + self._client = genai.Client(api_key=self.api_key) + + super().__init__() + + def infer( + self, batch_prompts: Sequence[str], **kwargs: Any + ) -> Iterator[Sequence[lx.inference.ScoredOutput]]: + """Run inference on a batch of prompts. + + Args: + batch_prompts: Input prompts to process. + **kwargs: Additional generation parameters. + + Yields: + Lists of ScoredOutputs, one per prompt. 
+ """ + config = { + 'temperature': kwargs.get('temperature', self.temperature), + } + + # Add other parameters if provided + for key in ['max_output_tokens', 'top_p', 'top_k']: + if key in kwargs: + config[key] = kwargs[key] + + for prompt in batch_prompts: + try: + # TODO: Replace this with your own API/model calls + response = self._client.models.generate_content( + model=self.model_id, contents=prompt, config=config + ) + output = response.text.strip() + yield [lx.inference.ScoredOutput(score=1.0, output=output)] + + except Exception as e: + raise lx.exceptions.InferenceRuntimeError( + f'API error: {str(e)}', original=e + ) from e diff --git a/examples/custom_provider_plugin/pyproject.toml b/examples/custom_provider_plugin/pyproject.toml new file mode 100644 index 00000000..bb1e12ed --- /dev/null +++ b/examples/custom_provider_plugin/pyproject.toml @@ -0,0 +1,38 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "langextract-provider-example" # Change to your package name +version = "0.1.0" # Update version for releases +description = "Example custom provider plugin for LangExtract" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "Apache-2.0"} +dependencies = [ + # Uncomment when creating a standalone plugin package: + # "langextract", # Will install latest version + "google-genai>=0.2.0", # Replace with your backend's SDK +] + +# Register the provider with LangExtract's plugin system +[project.entry-points."langextract.providers"] +custom_gemini = "langextract_provider_example:CustomGeminiProvider" + +[tool.setuptools.packages.find] +where = ["."] +include = ["langextract_provider_example*"] diff --git a/examples/custom_provider_plugin/test_example_provider.py b/examples/custom_provider_plugin/test_example_provider.py new file mode 100644 index 00000000..13ef8494 --- /dev/null +++ b/examples/custom_provider_plugin/test_example_provider.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple test for the custom provider plugin.""" + +import os + +# Import the provider to trigger registration with LangExtract +# Note: This manual import is only needed when running without installation. +# After `pip install -e .`, the entry point system handles this automatically. 
+from langextract_provider_example import CustomGeminiProvider # noqa: F401 + +import langextract as lx + + +def main(): + """Test the custom provider.""" + api_key = os.getenv("GEMINI_API_KEY") or os.getenv("LANGEXTRACT_API_KEY") + + if not api_key: + print("Set GEMINI_API_KEY or LANGEXTRACT_API_KEY to test") + return + + # Create model using explicit provider selection + config = lx.factory.ModelConfig( + model_id="gemini-2.5-flash", + provider="CustomGeminiProvider", + provider_kwargs={"api_key": api_key}, + ) + model = lx.factory.create_model(config) + + print(f"✓ Created {model.__class__.__name__}") + + # Test inference + prompts = ["Say hello"] + results = list(model.infer(prompts)) + + if results and results[0]: + print(f"✓ Inference worked: {results[0][0].output[:50]}...") + else: + print("✗ No response") + + +if __name__ == "__main__": + main() diff --git a/langextract/__init__.py b/langextract/__init__.py index a278a095..c32ce18d 100644 --- a/langextract/__init__.py +++ b/langextract/__init__.py @@ -17,7 +17,6 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -import os from typing import Any, cast, Type, TypeVar import warnings @@ -26,9 +25,11 @@ from langextract import annotation from langextract import data from langextract import exceptions +from langextract import factory from langextract import inference from langextract import io from langextract import prompting +from langextract import providers from langextract import resolver from langextract import schema from langextract import visualization @@ -39,9 +40,11 @@ "annotation", "data", "exceptions", + "factory", "inference", "io", "prompting", + "providers", "resolver", "schema", "visualization", @@ -53,6 +56,10 @@ visualize = visualization.visualize # Load environment variables from .env file +# NOTE: This behavior will be changed to opt-in in v2.0.0 +# Libraries typically should not auto-load .env files, but this is kept +# for backward compatibility. Users can set environment variables directly +# or use python-dotenv explicitly in their own code. dotenv.load_dotenv() @@ -191,23 +198,24 @@ def extract( ): model_schema = schema.GeminiSchema.from_examples(prompt_template.examples) - if not api_key: - api_key = os.environ.get("LANGEXTRACT_API_KEY") - - # Currently only Gemini is supported - if not api_key and language_model_type == inference.GeminiLanguageModel: - raise ValueError( - "API key must be provided for cloud-hosted models via the api_key" - " parameter or the LANGEXTRACT_API_KEY environment variable" - ) + # Handle backward compatibility for language_model_type parameter + if language_model_type != inference.GeminiLanguageModel: + warnings.warn( + "The 'language_model_type' parameter is deprecated and will be removed" + " in a future version. The provider is now automatically selected based" + " on the model_id.", + DeprecationWarning, + stacklevel=2, + ) + # Use factory to create the language model base_lm_kwargs: dict[str, Any] = { "api_key": api_key, - "model_id": model_id, "gemini_schema": model_schema, "format_type": format_type, "temperature": temperature, "model_url": model_url, + "base_url": model_url, # Support both parameter names for Ollama "constraint": schema_constraint, "max_workers": max_workers, } @@ -215,9 +223,15 @@ def extract( # Merge user-provided params which have precedence over defaults. 
base_lm_kwargs.update(language_model_params or {}) + # Filter out None values filtered_kwargs = {k: v for k, v in base_lm_kwargs.items() if v is not None} - language_model = language_model_type(**filtered_kwargs) + # Create model using factory + # Providers are loaded lazily by the registry on first resolve + config = factory.ModelConfig( + model_id=model_id, provider_kwargs=filtered_kwargs + ) + language_model = factory.create_model(config) resolver_defaults = { "fence_output": fence_output, diff --git a/langextract/exceptions.py b/langextract/exceptions.py index b3103ab7..1ac90a0a 100644 --- a/langextract/exceptions.py +++ b/langextract/exceptions.py @@ -14,7 +14,14 @@ """Base exceptions for LangExtract.""" -__all__ = ["LangExtractError"] +from __future__ import annotations + +__all__ = [ + "LangExtractError", + "InferenceError", + "InferenceConfigError", + "InferenceRuntimeError", +] class LangExtractError(Exception): @@ -24,3 +31,41 @@ class LangExtractError(Exception): This allows users to catch all LangExtract-specific errors with a single except clause. """ + + +class InferenceError(LangExtractError): + """Base exception for inference-related errors.""" + + +class InferenceConfigError(InferenceError): + """Exception raised for configuration errors. + + This includes missing API keys, invalid model IDs, or other + configuration-related issues that prevent model instantiation. + """ + + +class InferenceRuntimeError(InferenceError): + """Exception raised for runtime inference errors. + + This includes API call failures, network errors, or other issues + that occur during inference execution. + """ + + def __init__( + self, + message: str, + *, + original: BaseException | None = None, + provider: str | None = None, + ) -> None: + """Initialize the runtime error. + + Args: + message: Error message. + original: Original exception from the provider SDK. + provider: Name of the provider that raised the error. + """ + super().__init__(message) + self.original = original + self.provider = provider diff --git a/langextract/factory.py b/langextract/factory.py new file mode 100644 index 00000000..3f3bde65 --- /dev/null +++ b/langextract/factory.py @@ -0,0 +1,153 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory for creating language model instances. + +This module provides a factory pattern for instantiating language models +based on configuration, with support for environment variable resolution +and provider-specific defaults. +""" + +from __future__ import annotations + +import dataclasses +import os +import typing + +from langextract import exceptions +from langextract import inference +from langextract.providers import registry + + +@dataclasses.dataclass(slots=True, frozen=True) +class ModelConfig: + """Configuration for instantiating a language model provider. + + Attributes: + model_id: The model identifier (e.g., "gemini-2.5-flash", "gpt-4o"). + provider: Optional explicit provider name or class name. 
Use this to + disambiguate when multiple providers support the same model_id. + provider_kwargs: Optional provider-specific keyword arguments. + """ + + model_id: str | None = None + provider: str | None = None + provider_kwargs: dict[str, typing.Any] = dataclasses.field( + default_factory=dict + ) + + +def _kwargs_with_environment_defaults( + model_id: str, kwargs: dict[str, typing.Any] +) -> dict[str, typing.Any]: + """Add environment-based defaults to provider kwargs. + + Args: + model_id: The model identifier. + kwargs: Existing keyword arguments. + + Returns: + Updated kwargs with environment defaults. + """ + resolved = dict(kwargs) + + if "api_key" not in resolved: + model_lower = model_id.lower() + env_vars_by_provider = { + "gemini": ("GEMINI_API_KEY", "LANGEXTRACT_API_KEY"), + "gpt": ("OPENAI_API_KEY", "LANGEXTRACT_API_KEY"), + } + + for provider_prefix, env_vars in env_vars_by_provider.items(): + if provider_prefix in model_lower: + for env_var in env_vars: + api_key = os.getenv(env_var) + if api_key: + resolved["api_key"] = api_key + break + break + + if "ollama" in model_id.lower() and "base_url" not in resolved: + resolved["base_url"] = os.getenv( + "OLLAMA_BASE_URL", "http://localhost:11434" + ) + + return resolved + + +def create_model(config: ModelConfig) -> inference.BaseLanguageModel: + """Create a language model instance from configuration. + + Args: + config: Model configuration with optional model_id and/or provider. + + Returns: + An instantiated language model provider. + + Raises: + ValueError: If neither model_id nor provider is specified. + ValueError: If no provider is registered for the model_id. + InferenceConfigError: If provider instantiation fails. + """ + if not config.model_id and not config.provider: + raise ValueError("Either model_id or provider must be specified") + + try: + if config.provider: + provider_class = registry.resolve_provider(config.provider) + else: + provider_class = registry.resolve(config.model_id) + except (ModuleNotFoundError, ImportError) as e: + raise exceptions.InferenceConfigError( + "Failed to load provider. " + "This may be due to missing dependencies. " + f"Check that all required packages are installed. Error: {e}" + ) from e + + model_id = config.model_id + + kwargs = _kwargs_with_environment_defaults( + model_id or config.provider or "", config.provider_kwargs + ) + + if model_id: + kwargs["model_id"] = model_id + + try: + return provider_class(**kwargs) + except (ValueError, TypeError) as e: + raise exceptions.InferenceConfigError( + f"Failed to create provider {provider_class.__name__}: {e}" + ) from e + + +def create_model_from_id( + model_id: str | None = None, + provider: str | None = None, + **provider_kwargs: typing.Any, +) -> inference.BaseLanguageModel: + """Convenience function to create a model. + + Args: + model_id: The model identifier (e.g., "gemini-2.5-flash"). + provider: Optional explicit provider name to disambiguate. + **provider_kwargs: Optional provider-specific keyword arguments. + + Returns: + An instantiated language model provider. 
+ """ + config = ModelConfig( + model_id=model_id, provider=provider, provider_kwargs=provider_kwargs + ) + return create_model(config) diff --git a/langextract/inference.py b/langextract/inference.py index c32cbaf4..b43977bc 100644 --- a/langextract/inference.py +++ b/langextract/inference.py @@ -15,18 +15,15 @@ """Simple library for performing language model inference.""" import abc -from collections.abc import Iterator, Mapping, Sequence -import concurrent.futures +from collections.abc import Iterator, Sequence import dataclasses import enum import json import textwrap from typing import Any -from google import genai -import openai -import requests -from typing_extensions import override +from absl import logging +from typing_extensions import deprecated import yaml from langextract import data @@ -44,10 +41,11 @@ class ScoredOutput: output: str | None = None def __str__(self) -> str: + score_str = '-' if self.score is None else f'{self.score:.2f}' if self.output is None: - return f'Score: {self.score:.2f}\nOutput: None' + return f'Score: {score_str}\nOutput: None' formatted_lines = textwrap.indent(self.output, prefix=' ') - return f'Score: {self.score:.2f}\nOutput:\n{formatted_lines}' + return f'Score: {score_str}\nOutput:\n{formatted_lines}' class InferenceOutputError(exceptions.LangExtractError): @@ -91,479 +89,209 @@ def infer( score. """ + def infer_batch( + self, prompts: Sequence[str], batch_size: int = 32 # pylint: disable=unused-argument + ) -> list[list[ScoredOutput]]: + """Batch inference with configurable batch size. -class InferenceType(enum.Enum): - ITERATIVE = 'iterative' - MULTIPROCESS = 'multiprocess' + This is a convenience method that collects all results from infer(). + Args: + prompts: List of prompts to process. + batch_size: Batch size (currently unused, for future optimization). -@dataclasses.dataclass(init=False) -class OllamaLanguageModel(BaseLanguageModel): - """Language model inference class using Ollama based host.""" - - _model: str - _model_url: str - _structured_output_format: str - _constraint: schema.Constraint = dataclasses.field( - default_factory=schema.Constraint, repr=False, compare=False - ) - _extra_kwargs: dict[str, Any] = dataclasses.field( - default_factory=dict, repr=False, compare=False - ) - - def __init__( - self, - model_id: str, - model_url: str = _OLLAMA_DEFAULT_MODEL_URL, - structured_output_format: str = 'json', - constraint: schema.Constraint = schema.Constraint(), - **kwargs, - ) -> None: - self._model = model_id - self._model_url = model_url - self._structured_output_format = structured_output_format - self._constraint = constraint - self._extra_kwargs = kwargs or {} - super().__init__(constraint=constraint) + Returns: + List of lists of ScoredOutput objects. + """ + results = [] + for output in self.infer(prompts): + results.append(list(output)) + return results - @override - def infer( - self, batch_prompts: Sequence[str], **kwargs - ) -> Iterator[Sequence[ScoredOutput]]: - for prompt in batch_prompts: - response = self._ollama_query( - prompt=prompt, - model=self._model, - structured_output_format=self._structured_output_format, - model_url=self._model_url, - ) - # No score for Ollama. 
Default to 1.0 - yield [ScoredOutput(score=1.0, output=response['response'])] - - def _ollama_query( - self, - prompt: str, - model: str = 'gemma2:latest', - temperature: float = 0.8, - seed: int | None = None, - top_k: int | None = None, - max_output_tokens: int | None = None, - structured_output_format: str | None = None, # like 'json' - system: str = '', - raw: bool = False, - model_url: str = _OLLAMA_DEFAULT_MODEL_URL, - timeout: int = 30, # seconds - keep_alive: int = 5 * 60, # if loading, keep model up for 5 minutes. - num_threads: int | None = None, - num_ctx: int = 2048, - ) -> Mapping[str, Any]: - """Sends a prompt to an Ollama model and returns the generated response. - - This function makes an HTTP POST request to the `/api/generate` endpoint of - an Ollama server. It can optionally load the specified model first, generate - a response (with or without streaming), then return a parsed JSON response. + def parse_output(self, output: str) -> Any: + """Parses model output as JSON or YAML. + + Note: This expects raw JSON/YAML without code fences. + Code fence extraction is handled by resolver.py. Args: - prompt: The text prompt to send to the model. - model: The name of the model to use, e.g. "gemma2:latest". - temperature: Sampling temperature. Higher values produce more diverse - output. - seed: Seed for reproducible generation. If None, random seed is used. - top_k: The top-K parameter for sampling. - max_output_tokens: Maximum tokens to generate. If None, the model's - default is used. - structured_output_format: If set to "json" or a JSON schema dict, requests - structured outputs from the model. See Ollama documentation for details. - system: A system prompt to override any system-level instructions. - raw: If True, bypasses any internal prompt templating; you provide the - entire raw prompt. - model_url: The base URL for the Ollama server, typically - "http://localhost:11434". - timeout: Timeout (in seconds) for the HTTP request. - keep_alive: How long (in seconds) the model remains loaded after - generation completes. - num_threads: Number of CPU threads to use. If None, Ollama uses a default - heuristic. - num_ctx: Number of context tokens allowed. If None, uses model’s default - or config. + output: Raw output string from the model. Returns: - A mapping (dictionary-like) containing the server’s JSON response. For - non-streaming calls, the `"response"` key typically contains the entire - generated text. + Parsed Python object (dict or list). Raises: - ValueError: If the server returns a 404 (model not found) or any non-OK - status code other than 200. Also raised on read timeouts or other request - exceptions. + ValueError: If output cannot be parsed as JSON or YAML. 
""" - options = {'keep_alive': keep_alive} - if seed: - options['seed'] = seed - if temperature: - options['temperature'] = temperature - if top_k: - options['top_k'] = top_k - if num_threads: - options['num_thread'] = num_threads - if max_output_tokens: - options['num_predict'] = max_output_tokens - if num_ctx: - options['num_ctx'] = num_ctx - model_url = model_url + '/api/generate' - - payload = { - 'model': model, - 'prompt': prompt, - 'system': system, - 'stream': False, - 'raw': raw, - 'format': structured_output_format, - 'options': options, - } + # Check if we have a format_type attribute (providers should set this) + format_type = getattr(self, 'format_type', data.FormatType.JSON) + try: - response = requests.post( - model_url, - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json', - }, - json=payload, - timeout=timeout, - ) - except requests.exceptions.RequestException as e: - if isinstance(e, requests.exceptions.ReadTimeout): - msg = ( - f'Ollama Model timed out (timeout={timeout},' - f' num_threads={num_threads})' - ) - raise ValueError(msg) from e - raise e - - response.encoding = 'utf-8' - if response.status_code == 200: - return response.json() - if response.status_code == 404: - raise ValueError( - f"Can't find Ollama {model}. Try launching `ollama run {model}`" - ' from command line.' - ) - else: + if format_type == data.FormatType.JSON: + return json.loads(output) + else: + return yaml.safe_load(output) + except Exception as e: raise ValueError( - f'Ollama model failed with status code {response.status_code}.' - ) + f'Failed to parse output as {format_type.name}: {str(e)}' + ) from e -@dataclasses.dataclass(init=False) -class GeminiLanguageModel(BaseLanguageModel): - """Language model inference using Google's Gemini API with structured output.""" - - model_id: str = 'gemini-2.5-flash' - api_key: str | None = None - gemini_schema: schema.GeminiSchema | None = None - format_type: data.FormatType = data.FormatType.JSON - temperature: float = 0.0 - max_workers: int = 10 - _extra_kwargs: dict[str, Any] = dataclasses.field( - default_factory=dict, repr=False, compare=False - ) - - def __init__( - self, - model_id: str = 'gemini-2.5-flash', - api_key: str | None = None, - gemini_schema: schema.GeminiSchema | None = None, - format_type: data.FormatType = data.FormatType.JSON, - temperature: float = 0.0, - max_workers: int = 10, - **kwargs, - ) -> None: - """Initialize the Gemini language model. +class InferenceType(enum.Enum): + ITERATIVE = 'iterative' + MULTIPROCESS = 'multiprocess' - Args: - model_id: The Gemini model ID to use. - api_key: API key for Gemini service. - gemini_schema: Optional schema for structured output. - format_type: Output format (JSON or YAML). - temperature: Sampling temperature. - max_workers: Maximum number of parallel API calls. - **kwargs: Ignored extra parameters so callers can pass a superset of - arguments shared across back-ends without raising ``TypeError``. - """ - self.model_id = model_id - self.api_key = api_key - self.gemini_schema = gemini_schema - self.format_type = format_type - self.temperature = temperature - self.max_workers = max_workers - self._extra_kwargs = kwargs or {} - if not self.api_key: - raise ValueError('API key not provided.') +@deprecated( + 'Use langextract.providers.ollama.OllamaLanguageModel instead. ' + 'Will be removed in v2.0.0.' +) +class OllamaLanguageModel(BaseLanguageModel): + """Language model inference class using Ollama based host. 
- self._client = genai.Client(api_key=self.api_key) + DEPRECATED: Use langextract.providers.ollama.OllamaLanguageModel instead. + This class is kept for backward compatibility only. + """ - super().__init__( - constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) + def __init__(self, **kwargs): + """Initialize the Ollama language model (deprecated).""" + logging.warning( + 'OllamaLanguageModel from langextract.inference is deprecated. ' + 'Use langextract.providers.ollama.OllamaLanguageModel instead.' ) - def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput: - """Process a single prompt and return a ScoredOutput.""" - try: - if self.gemini_schema: - response_schema = self.gemini_schema.schema_dict - mime_type = ( - 'application/json' - if self.format_type == data.FormatType.JSON - else 'application/yaml' - ) - config['response_mime_type'] = mime_type - config['response_schema'] = response_schema - - response = self._client.models.generate_content( - model=self.model_id, contents=prompt, config=config + # pylint: disable=import-outside-toplevel + from langextract.providers import ollama # Avoid circular import + + # Convert old parameter names to new ones + if 'model' in kwargs: + kwargs['model_id'] = kwargs.pop('model') + + if 'structured_output_format' in kwargs: + format_str = kwargs.pop('structured_output_format') + kwargs['format_type'] = ( + data.FormatType.JSON if format_str == 'json' else data.FormatType.YAML ) - return ScoredOutput(score=1.0, output=response.text) + self._impl = ollama.OllamaLanguageModel(**kwargs) + self._model = self._impl._model + self._model_url = self._impl._model_url + self.format_type = ( + self._impl.format_type + ) # Changed from _structured_output_format + self._constraint = self._impl._constraint + self._extra_kwargs = self._impl._extra_kwargs - except Exception as e: - raise InferenceOutputError(f'Gemini API error: {str(e)}') from e + super().__init__(constraint=self._impl._constraint) + + def _ollama_query(self, **kwargs): + """Backward compatibility method.""" + return self._impl._ollama_query(**kwargs) # pylint: disable=protected-access def infer( self, batch_prompts: Sequence[str], **kwargs ) -> Iterator[Sequence[ScoredOutput]]: - """Runs inference on a list of prompts via Gemini's API. - - Args: - batch_prompts: A list of string prompts. - **kwargs: Additional generation params (temperature, top_p, top_k, etc.) - - Yields: - Lists of ScoredOutputs. 
- """ - config = { - 'temperature': kwargs.get('temperature', self.temperature), - } - if 'max_output_tokens' in kwargs: - config['max_output_tokens'] = kwargs['max_output_tokens'] - if 'top_p' in kwargs: - config['top_p'] = kwargs['top_p'] - if 'top_k' in kwargs: - config['top_k'] = kwargs['top_k'] - - # Use parallel processing for batches larger than 1 - if len(batch_prompts) > 1 and self.max_workers > 1: - with concurrent.futures.ThreadPoolExecutor( - max_workers=min(self.max_workers, len(batch_prompts)) - ) as executor: - future_to_index = { - executor.submit( - self._process_single_prompt, prompt, config.copy() - ): i - for i, prompt in enumerate(batch_prompts) - } - - results: list[ScoredOutput | None] = [None] * len(batch_prompts) - for future in concurrent.futures.as_completed(future_to_index): - index = future_to_index[future] - try: - results[index] = future.result() - except Exception as e: - raise InferenceOutputError( - f'Parallel inference error: {str(e)}' - ) from e - - for result in results: - if result is None: - raise InferenceOutputError('Failed to process one or more prompts') - yield [result] - else: - # Sequential processing for single prompt or worker - for prompt in batch_prompts: - result = self._process_single_prompt(prompt, config.copy()) - yield [result] + """Delegate to new provider.""" + return self._impl.infer(batch_prompts, **kwargs) def parse_output(self, output: str) -> Any: - """Parses Gemini output as JSON or YAML. + """Delegate to new provider.""" + return self._impl.parse_output(output) - Note: This expects raw JSON/YAML without code fences. - Code fence extraction is handled by resolver.py. - """ - try: - if self.format_type == data.FormatType.JSON: - return json.loads(output) - else: - return yaml.safe_load(output) - except Exception as e: - raise ValueError( - f'Failed to parse output as {self.format_type.name}: {str(e)}' - ) from e +@deprecated( + 'Use langextract.providers.gemini.GeminiLanguageModel instead. ' + 'Will be removed in v2.0.0.' +) +class GeminiLanguageModel(BaseLanguageModel): + """Language model inference using Google's Gemini API with structured output. -@dataclasses.dataclass(init=False) -class OpenAILanguageModel(BaseLanguageModel): - """Language model inference using OpenAI's API with structured output.""" - - model_id: str = 'gpt-4o-mini' - api_key: str | None = None - base_url: str | None = None - organization: str | None = None - format_type: data.FormatType = data.FormatType.JSON - temperature: float = 0.0 - max_workers: int = 10 - _client: openai.OpenAI | None = dataclasses.field( - default=None, repr=False, compare=False - ) - _extra_kwargs: dict[str, Any] = dataclasses.field( - default_factory=dict, repr=False, compare=False - ) - - def __init__( - self, - model_id: str = 'gpt-4o-mini', - api_key: str | None = None, - base_url: str | None = None, - organization: str | None = None, - format_type: data.FormatType = data.FormatType.JSON, - temperature: float = 0.0, - max_workers: int = 10, - **kwargs, - ) -> None: - """Initialize the OpenAI language model. - - Args: - model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o'). - api_key: API key for OpenAI service. - base_url: Base URL for OpenAI service. - organization: Optional OpenAI organization ID. - format_type: Output format (JSON or YAML). - temperature: Sampling temperature. - max_workers: Maximum number of parallel API calls. 
- **kwargs: Ignored extra parameters so callers can pass a superset of - arguments shared across back-ends without raising ``TypeError``. - """ - self.model_id = model_id - self.api_key = api_key - self.base_url = base_url - self.organization = organization - self.format_type = format_type - self.temperature = temperature - self.max_workers = max_workers - self._extra_kwargs = kwargs or {} - - if not self.api_key: - raise ValueError('API key not provided.') - - # Initialize the OpenAI client - self._client = openai.OpenAI( - api_key=self.api_key, - base_url=self.base_url, - organization=self.organization, - ) + DEPRECATED: Use langextract.providers.gemini.GeminiLanguageModel instead. + This class is kept for backward compatibility only. + """ - super().__init__( - constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) + def __init__(self, **kwargs): + """Initialize the Gemini language model (deprecated).""" + logging.warning( + 'GeminiLanguageModel from langextract.inference is deprecated. ' + 'Use langextract.providers.gemini.GeminiLanguageModel instead.' ) - def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput: - """Process a single prompt and return a ScoredOutput.""" - try: - # Prepare the system message for structured output - system_message = '' - if self.format_type == data.FormatType.JSON: - system_message = ( - 'You are a helpful assistant that responds in JSON format.' - ) - elif self.format_type == data.FormatType.YAML: - system_message = ( - 'You are a helpful assistant that responds in YAML format.' - ) - - # Create the chat completion using the v1.x client API - response = self._client.chat.completions.create( - model=self.model_id, - messages=[ - {'role': 'system', 'content': system_message}, - {'role': 'user', 'content': prompt}, - ], - temperature=config.get('temperature', self.temperature), - max_tokens=config.get('max_output_tokens'), - top_p=config.get('top_p'), - n=1, - ) + # pylint: disable=import-outside-toplevel + from langextract.providers import gemini # Avoid circular import - # Extract the response text using the v1.x response format - output_text = response.choices[0].message.content + self._impl = gemini.GeminiLanguageModel(**kwargs) + self.model_id = self._impl.model_id + self.api_key = self._impl.api_key + self.gemini_schema = self._impl.gemini_schema + self.format_type = self._impl.format_type + self.temperature = self._impl.temperature + self.max_workers = self._impl.max_workers - return ScoredOutput(score=1.0, output=output_text) - - except Exception as e: - raise InferenceOutputError(f'OpenAI API error: {str(e)}') from e + super().__init__(constraint=self._impl._constraint) def infer( self, batch_prompts: Sequence[str], **kwargs ) -> Iterator[Sequence[ScoredOutput]]: - """Runs inference on a list of prompts via OpenAI's API. + """Delegate to new provider.""" + return self._impl.infer(batch_prompts, **kwargs) - Args: - batch_prompts: A list of string prompts. - **kwargs: Additional generation params (temperature, top_p, etc.) + def parse_output(self, output: str) -> Any: + """Delegate to new provider.""" + return self._impl.parse_output(output) - Yields: - Lists of ScoredOutputs. 
- """ - config = { - 'temperature': kwargs.get('temperature', self.temperature), - } - if 'max_output_tokens' in kwargs: - config['max_output_tokens'] = kwargs['max_output_tokens'] - if 'top_p' in kwargs: - config['top_p'] = kwargs['top_p'] - - # Use parallel processing for batches larger than 1 - if len(batch_prompts) > 1 and self.max_workers > 1: - with concurrent.futures.ThreadPoolExecutor( - max_workers=min(self.max_workers, len(batch_prompts)) - ) as executor: - future_to_index = { - executor.submit( - self._process_single_prompt, prompt, config.copy() - ): i - for i, prompt in enumerate(batch_prompts) - } - - results: list[ScoredOutput | None] = [None] * len(batch_prompts) - for future in concurrent.futures.as_completed(future_to_index): - index = future_to_index[future] - try: - results[index] = future.result() - except Exception as e: - raise InferenceOutputError( - f'Parallel inference error: {str(e)}' - ) from e - - for result in results: - if result is None: - raise InferenceOutputError('Failed to process one or more prompts') - yield [result] - else: - # Sequential processing for single prompt or worker - for prompt in batch_prompts: - result = self._process_single_prompt(prompt, config.copy()) - yield [result] - def parse_output(self, output: str) -> Any: - """Parses OpenAI output as JSON or YAML. +@deprecated( + 'Use langextract.providers.openai.OpenAILanguageModel instead. ' + 'Will be removed in v2.0.0.' +) +class OpenAILanguageModel(BaseLanguageModel): # pylint: disable=too-many-instance-attributes + """Language model inference using OpenAI's API with structured output. + + DEPRECATED: Use langextract.providers.openai.OpenAILanguageModel instead. + This class is kept for backward compatibility only. + """ + + def __init__(self, **kwargs): + """Initialize the OpenAI language model (deprecated).""" + logging.warning( + 'OpenAILanguageModel from langextract.inference is deprecated. ' + 'Use langextract.providers.openai.OpenAILanguageModel instead.' + ) + + # pylint: disable=import-outside-toplevel + from langextract.providers import openai # Avoid circular import - Note: This expects raw JSON/YAML without code fences. - Code fence extraction is handled by resolver.py. - """ try: - if self.format_type == data.FormatType.JSON: - return json.loads(output) - else: - return yaml.safe_load(output) - except Exception as e: + self._impl = openai.OpenAILanguageModel(**kwargs) + except exceptions.InferenceConfigError as e: + # Convert to ValueError for backward compatibility raise ValueError( - f'Failed to parse output as {self.format_type.name}: {str(e)}' + str(e).replace( + 'API key not provided for OpenAI.', 'API key not provided.' 
+ ) ) from e + self.model_id = self._impl.model_id + self.api_key = self._impl.api_key + self.base_url = self._impl.base_url + self.organization = self._impl.organization + self.format_type = self._impl.format_type + self.temperature = self._impl.temperature + self.max_workers = self._impl.max_workers + self._client = self._impl._client + + self._process_single_prompt = ( + self._impl._process_single_prompt + ) # For test compatibility + + super().__init__(constraint=self._impl._constraint) + + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[ScoredOutput]]: + """Delegate to new provider.""" + return self._impl.infer(batch_prompts, **kwargs) + + def parse_output(self, output: str) -> Any: + """Delegate to new provider.""" + return self._impl.parse_output(output) diff --git a/langextract/providers/README.md b/langextract/providers/README.md new file mode 100644 index 00000000..1395fb38 --- /dev/null +++ b/langextract/providers/README.md @@ -0,0 +1,372 @@ +# LangExtract Provider System + +This directory contains the provider system for LangExtract, which enables support for different Large Language Model (LLM) backends. + +## Architecture Overview + +The provider system uses a **registry pattern** with **automatic discovery**: + +1. **Registry** (`registry.py`): Maps model ID patterns to provider classes +2. **Factory** (`../factory.py`): Creates provider instances based on model IDs +3. **Providers**: Implement the `BaseLanguageModel` interface + +### Provider Resolution Flow + +``` +User Code LangExtract Provider +───────── ─────────── ──────── + | | | + | lx.extract( | | + | model_id="gemini-2.5-flash") | + |─────────────────────────────> | + | | | + | factory.create_model() | + | | | + | registry.resolve("gemini-2.5-flash") | + | Pattern match: ^gemini | + | ↓ | + | GeminiLanguageModel | + | | | + | Instantiate provider | + | |─────────────────────────────>| + | | | + | | Provider API calls | + | |<─────────────────────────────| + | | | + |<──────────────────────────── | + | AnnotatedDocument | | +``` + +### Explicit Provider Selection + +When multiple providers might support the same model ID, or when you want to use a specific provider, you can explicitly specify the provider: + +```python +import langextract as lx + +# Method 1: Using factory directly with provider parameter +config = lx.factory.ModelConfig( + model_id="gpt-4", + provider="OpenAILanguageModel", # Explicit provider + provider_kwargs={"api_key": "..."} +) +model = lx.factory.create_model(config) + +# Method 2: Using provider without model_id (uses provider's default) +config = lx.factory.ModelConfig( + provider="GeminiLanguageModel", # Will use default gemini-2.5-flash + provider_kwargs={"api_key": "..."} +) +model = lx.factory.create_model(config) + +# Method 3: Auto-detection (when no conflicts exist) +config = lx.factory.ModelConfig( + model_id="gemini-2.5-flash" # Provider auto-detected +) +model = lx.factory.create_model(config) +``` + +Provider names can be: +- Full class name: `"GeminiLanguageModel"`, `"OpenAILanguageModel"`, `"OllamaLanguageModel"` +- Partial match: `"gemini"`, `"openai"`, `"ollama"` (case-insensitive) + +## Provider Types + +### 1. Core Providers (Always Available) +Ships with langextract, dependencies included: +- **Gemini** (`gemini.py`): Google's Gemini models +- **Ollama** (`ollama.py`): Local models via Ollama + +### 2. 
Built-in Provider with Optional Dependencies +Ships with langextract, but requires extra installation: +- **OpenAI** (`openai.py`): OpenAI's GPT models + - Code included in package + - Requires: `pip install langextract[openai]` to install OpenAI SDK + - Future: May be moved to external plugin package + +### 3. External Plugins (Third-party) +Separate packages that extend LangExtract with new providers: +- **Installed separately**: `pip install langextract-yourprovider` +- **Auto-discovered**: Uses Python entry points for automatic registration +- **Zero configuration**: Import langextract and the provider is available +- **Independent updates**: Update providers without touching core + +```python +# Install a third-party provider +pip install langextract-yourprovider + +# Use it immediately - no imports needed! +import langextract as lx +result = lx.extract( + text="...", + model_id="yourmodel-latest" # Automatically finds the provider +) +``` + +#### How Plugin Discovery Works + +``` +1. pip install langextract-yourprovider + └── Installs package containing: + • Provider class with @lx.providers.registry.register decorator + • Python entry point pointing to this class + +2. import langextract + └── Loads providers/__init__.py + └── Discovers and imports plugin via entry points + └── @lx.providers.registry.register decorator fires + └── Provider patterns added to registry + +3. lx.extract(model_id="yourmodel-latest") + └── Registry matches pattern and uses your provider +``` + +## How Provider Selection Works + +When you call `lx.extract(model_id="gemini-2.5-flash", ...)`, here's what happens: + +1. **Factory receives model_id**: "gemini-2.5-flash" +2. **Registry searches patterns**: Each provider registers regex patterns +3. **First match wins**: Returns the matching provider class +4. **Provider instantiated**: With model_id and any kwargs +5. **Inference runs**: Using the selected provider + +### Pattern Registration Example + +```python +import langextract as lx + +# Gemini provider registration: +@lx.providers.registry.register( + r'^GeminiLanguageModel$', # Explicit: model_id="GeminiLanguageModel" + r'^gemini', # Prefix: model_id="gemini-2.5-flash" + r'^palm' # Legacy: model_id="palm-2" +) +class GeminiLanguageModel(lx.inference.BaseLanguageModel): + def __init__(self, model_id: str, api_key: str = None, **kwargs): + # Initialize Gemini client + ... + + def infer(self, batch_prompts, **kwargs): + # Call Gemini API + ... +``` + +## Usage Examples + +### Using Default Provider Selection +```python +import langextract as lx + +# Automatically selects Gemini provider +result = lx.extract( + text="...", + model_id="gemini-2.5-flash" +) +``` + +### Passing Parameters to Providers + +Parameters flow from `lx.extract()` to providers through several mechanisms: + +```python +# 1. Common parameters handled by lx.extract itself: +result = lx.extract( + text="Your document", + model_id="gemini-2.5-flash", + prompt_description="Extract key facts", + examples=[...], # Used for few-shot prompting + num_workers=4, # Parallel processing + max_chunk_size=3000, # Document chunking +) + +# 2. 
Provider-specific parameters passed via **kwargs: +result = lx.extract( + text="Your document", + model_id="gemini-2.5-flash", + prompt_description="Extract entities", + # These go directly to the Gemini provider: + temperature=0.7, # Sampling temperature + api_key="your-key", # Override environment variable + max_output_tokens=1000, # Token limit +) +``` + +### Using the Factory for Advanced Control +```python +# When you need explicit provider selection or advanced configuration +from langextract import factory + +# Specify both model and provider (useful when multiple providers support same model) +config = factory.ModelConfig( + model_id="llama3.2:1b", + provider="OllamaLanguageModel", # Explicitly use Ollama + provider_kwargs={ + "model_url": "http://localhost:11434" + } +) +model = factory.create_model(config) +``` + +### Direct Provider Usage +```python +import langextract as lx + +# Direct import if you prefer (optional) +from langextract.providers.gemini import GeminiLanguageModel + +model = GeminiLanguageModel( + model_id="gemini-2.5-flash", + api_key="your-key" +) +outputs = model.infer(["prompt1", "prompt2"]) +``` + +## Creating a New Provider + +**📁 Complete Example**: See [examples/custom_provider_plugin/](../../examples/custom_provider_plugin/) for a fully-functional plugin template with testing and documentation. + +### Option 1: External Plugin (Recommended) + +External plugins are the recommended approach for adding new providers. They're easy to maintain, distribute, and don't require changes to the core package. + +#### For Users (Installing an External Plugin) +Simply install the plugin package: +```bash +pip install langextract-yourprovider +# That's it! The provider is now available in langextract +``` + +#### For Developers (Creating an External Plugin) + +1. Create a new package: +``` +langextract-myprovider/ +├── pyproject.toml +├── README.md +└── langextract_myprovider/ + └── __init__.py +``` + +2. Configure entry point in `pyproject.toml`: +```toml +[project] +name = "langextract-myprovider" +dependencies = ["langextract>=1.0.0", "your-sdk"] + +[project.entry-points."langextract.providers"] +# Pattern 1: Register the class directly +myprovider = "langextract_myprovider:MyProviderLanguageModel" + +# Pattern 2: Register a module that self-registers +# myprovider = "langextract_myprovider" +``` + +3. Implement your provider: +```python +# langextract_myprovider/__init__.py +import langextract as lx + +@lx.providers.registry.register(r'^mymodel', r'^custom') +class MyProviderLanguageModel(lx.inference.BaseLanguageModel): + def __init__(self, model_id: str, **kwargs): + super().__init__() + self.model_id = model_id + # Initialize your client + + def infer(self, batch_prompts, **kwargs): + # Implement inference + for prompt in batch_prompts: + result = self._call_api(prompt) + yield [lx.inference.ScoredOutput(score=1.0, output=result)] +``` + +**Pattern Registration Explained:** +- The `@register` decorator patterns (e.g., `r'^mymodel'`, `r'^custom'`) define which model IDs your provider supports +- When users call `lx.extract(model_id="mymodel-3b")`, the registry matches against these patterns +- Your provider will handle any model_id starting with "mymodel" or "custom" +- Users can explicitly select your provider using its class name: + ```python + config = lx.factory.ModelConfig(provider="MyProviderLanguageModel") + # Or partial match: provider="myprovider" (matches class name) + +4. 
Publish your package to PyPI: +```bash +pip install build twine +python -m build +twine upload dist/* +``` + +Now users can install and use your provider with just `pip install langextract-myprovider`! + +### Option 2: Built-in Provider (Requires Core Team Approval) + +**⚠️ Note**: Adding a provider to the core package requires: +- Significant community demand and support +- Commitment to long-term maintenance +- Approval from the LangExtract maintainers +- A pull request to the main repository + +This approach should only be used for providers that benefit a large portion of the user base. + +1. Create your provider file: +```python +# langextract/providers/myprovider.py +import langextract as lx + +@lx.providers.registry.register(r'^mymodel', r'^custom') +class MyProviderLanguageModel(lx.inference.BaseLanguageModel): + # Implementation same as above +``` + +2. Import it in `providers/__init__.py`: +```python +# In langextract/providers/__init__.py +from langextract.providers import myprovider # noqa: F401 +``` + +3. Submit a pull request with: + - Provider implementation + - Comprehensive tests + - Documentation + - Justification for inclusion in core + +## Environment Variables + +The factory automatically resolves API keys from environment: + +| Provider | Environment Variables (in priority order) | +|----------|------------------------------------------| +| Gemini | `GEMINI_API_KEY`, `LANGEXTRACT_API_KEY` | +| OpenAI | `OPENAI_API_KEY`, `LANGEXTRACT_API_KEY` | +| Ollama | `OLLAMA_BASE_URL` (default: http://localhost:11434) | + +## Design Principles + +1. **Zero Configuration**: Providers auto-register when imported +2. **Extensible**: Easy to add new providers without modifying core +3. **Lazy Loading**: Optional dependencies only loaded when needed +4. **Explicit Control**: Users can force specific providers when needed +5. **Pattern Priority**: All patterns have equal priority (0) by default + +## Migration Path for OpenAI + +Currently, OpenAI is an optional built-in provider. Future plan: +1. Move to external plugin package (`langextract-openai`) +2. Users install via `pip install langextract-openai` +3. Usage remains exactly the same +4. Benefits: Cleaner dependencies, better modularity + +## Common Issues + +### Provider Not Found +```python +ValueError: No provider registered for model_id='unknown-model' +``` +**Solution**: Check available patterns with `registry.list_entries()` + +### Missing Dependencies +```python +InferenceConfigError: OpenAI provider requires openai package +``` +**Solution**: Install optional dependencies: `pip install langextract[openai]` diff --git a/langextract/providers/__init__.py b/langextract/providers/__init__.py new file mode 100644 index 00000000..d4b65a81 --- /dev/null +++ b/langextract/providers/__init__.py @@ -0,0 +1,85 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provider package for LangExtract. + +This package contains the registry system and provider implementations +for various LLM backends. 
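+
+Example (an illustrative sketch; assumes the bundled Gemini provider is
+available so the ``^gemini`` pattern resolves):
+
+  from langextract.providers import registry
+
+  provider_cls = registry.resolve('gemini-2.5-flash')  # -> GeminiLanguageModel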
+""" +# pylint: disable=cyclic-import + +from importlib import metadata +import os + +from absl import logging + +from langextract.providers import registry + +# Track whether plugins have been loaded +_PLUGINS_LOADED = False + + +def load_plugins_once() -> None: + """Load third-party providers via entry points. + + This function is idempotent and will only load plugins once. + Set LANGEXTRACT_DISABLE_PLUGINS=1 to disable plugin loading. + """ + global _PLUGINS_LOADED # pylint: disable=global-statement + if _PLUGINS_LOADED: + return + + if os.getenv("LANGEXTRACT_DISABLE_PLUGINS") == "1": + logging.info("Plugin loading disabled by LANGEXTRACT_DISABLE_PLUGINS=1") + _PLUGINS_LOADED = True + return + + _PLUGINS_LOADED = True + + try: + entry_points_group = metadata.entry_points(group="langextract.providers") + except Exception as exc: + logging.debug("No third-party provider entry points found: %s", exc) + return + + for entry_point in entry_points_group: + try: + provider = entry_point.load() + + if isinstance(provider, type): + registry.register(entry_point.name)(provider) + logging.info( + "Registered third-party provider from entry point: %s", + entry_point.name, + ) + else: + logging.debug( + "Loaded provider module from entry point: %s", entry_point.name + ) + except Exception as exc: + logging.warning( + "Failed to load third-party provider %s: %s", entry_point.value, exc + ) + + +# pylint: disable=wrong-import-position +from langextract.providers import gemini # noqa: F401 +from langextract.providers import ollama # noqa: F401 + +try: + from langextract.providers import openai # noqa: F401 +except ImportError: + pass + +__all__ = ["registry", "load_plugins_once"] diff --git a/langextract/providers/gemini.py b/langextract/providers/gemini.py new file mode 100644 index 00000000..09ce2ec8 --- /dev/null +++ b/langextract/providers/gemini.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemini provider for LangExtract.""" +# pylint: disable=cyclic-import,duplicate-code + +from __future__ import annotations + +import concurrent.futures +import dataclasses +from typing import Any, Iterator, Sequence + +from langextract import data +from langextract import exceptions +from langextract import inference +from langextract import schema +from langextract.providers import registry + + +@registry.register( + r'^gemini', # gemini-2.5-flash, gemini-2.5-pro, etc. 
+ priority=10, +) +@dataclasses.dataclass(init=False) +class GeminiLanguageModel(inference.BaseLanguageModel): + """Language model inference using Google's Gemini API with structured output.""" + + model_id: str = 'gemini-2.5-flash' + api_key: str | None = None + gemini_schema: schema.GeminiSchema | None = None + format_type: data.FormatType = data.FormatType.JSON + temperature: float = 0.0 + max_workers: int = 10 + fence_output: bool = False + _extra_kwargs: dict[str, Any] = dataclasses.field( + default_factory=dict, repr=False, compare=False + ) + + def __init__( + self, + model_id: str = 'gemini-2.5-flash', + api_key: str | None = None, + gemini_schema: schema.GeminiSchema | None = None, + format_type: data.FormatType = data.FormatType.JSON, + temperature: float = 0.0, + max_workers: int = 10, + fence_output: bool = False, + **kwargs, + ) -> None: + """Initialize the Gemini language model. + + Args: + model_id: The Gemini model ID to use. + api_key: API key for Gemini service. + gemini_schema: Optional schema for structured output. + format_type: Output format (JSON or YAML). + temperature: Sampling temperature. + max_workers: Maximum number of parallel API calls. + fence_output: Whether to wrap output in markdown fences (ignored, + Gemini handles this based on schema). + **kwargs: Ignored extra parameters so callers can pass a superset of + arguments shared across back-ends without raising ``TypeError``. + """ + try: + # pylint: disable=import-outside-toplevel + from google import genai + except ImportError as e: + raise exceptions.InferenceConfigError( + 'Failed to import google-genai. Reinstall: pip install langextract' + ) from e + + self.model_id = model_id + self.api_key = api_key + self.gemini_schema = gemini_schema + self.format_type = format_type + self.temperature = temperature + self.max_workers = max_workers + self.fence_output = ( + fence_output # Store but may not use depending on schema + ) + self._extra_kwargs = kwargs or {} + + if not self.api_key: + raise exceptions.InferenceConfigError('API key not provided for Gemini.') + + self._client = genai.Client(api_key=self.api_key) + + super().__init__( + constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) + ) + + def _process_single_prompt( + self, prompt: str, config: dict + ) -> inference.ScoredOutput: + """Process a single prompt and return a ScoredOutput.""" + try: + if self.gemini_schema: + response_schema = self.gemini_schema.schema_dict + mime_type = ( + 'application/json' + if self.format_type == data.FormatType.JSON + else 'application/yaml' + ) + config['response_mime_type'] = mime_type + config['response_schema'] = response_schema + + response = self._client.models.generate_content( + model=self.model_id, contents=prompt, config=config # type: ignore[arg-type] + ) + + return inference.ScoredOutput(score=1.0, output=response.text) + + except Exception as e: + raise exceptions.InferenceRuntimeError( + f'Gemini API error: {str(e)}', original=e + ) from e + + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[inference.ScoredOutput]]: + """Runs inference on a list of prompts via Gemini's API. + + Args: + batch_prompts: A list of string prompts. + **kwargs: Additional generation params (temperature, top_p, top_k, etc.) + + Yields: + Lists of ScoredOutputs. 
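+
+    Example (illustrative sketch; 'YOUR_API_KEY' is a placeholder):
+
+      model = GeminiLanguageModel(api_key='YOUR_API_KEY')
+      for outputs in model.infer(['List the medications mentioned.']):
+        print(outputs[0].output)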
+ """ + config = { + 'temperature': kwargs.get('temperature', self.temperature), + } + if 'max_output_tokens' in kwargs: + config['max_output_tokens'] = kwargs['max_output_tokens'] + if 'top_p' in kwargs: + config['top_p'] = kwargs['top_p'] + if 'top_k' in kwargs: + config['top_k'] = kwargs['top_k'] + + # Use parallel processing for batches larger than 1 + if len(batch_prompts) > 1 and self.max_workers > 1: + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(self.max_workers, len(batch_prompts)) + ) as executor: + future_to_index = { + executor.submit( + self._process_single_prompt, prompt, config.copy() + ): i + for i, prompt in enumerate(batch_prompts) + } + + results: list[inference.ScoredOutput | None] = [None] * len( + batch_prompts + ) + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + raise exceptions.InferenceRuntimeError( + f'Parallel inference error: {str(e)}', original=e + ) from e + + for result in results: + if result is None: + raise exceptions.InferenceRuntimeError( + 'Failed to process one or more prompts' + ) + yield [result] + else: + # Sequential processing for single prompt or worker + for prompt in batch_prompts: + result = self._process_single_prompt(prompt, config.copy()) + yield [result] diff --git a/langextract/providers/ollama.py b/langextract/providers/ollama.py new file mode 100644 index 00000000..dddf2d55 --- /dev/null +++ b/langextract/providers/ollama.py @@ -0,0 +1,271 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ollama provider for LangExtract.""" +# pylint: disable=cyclic-import,duplicate-code + +from __future__ import annotations + +import dataclasses +from typing import Any, Iterator, Mapping, Sequence +import warnings + +import requests + +from langextract import data +from langextract import exceptions +from langextract import inference +from langextract import schema +from langextract.providers import registry + +_OLLAMA_DEFAULT_MODEL_URL = 'http://localhost:11434' + + +@registry.register( + # Latest open models via Ollama (2024-2025) + r'^gemma', # gemma2:2b, gemma2:9b, gemma2:27b, etc. + r'^llama', # llama3.2:1b, llama3.2:3b, llama3.1:8b, llama3.1:70b, etc. + r'^mistral', # mistral:7b, mistral-nemo:12b, mistral-large, etc. + r'^mixtral', # mixtral:8x7b, mixtral:8x22b, etc. + r'^phi', # phi3:mini, phi3:medium, phi3.5, etc. + r'^qwen', # qwen3:8b, qwen2.5:7b, qwen2.5:32b, qwen2.5-coder, etc. + r'^deepseek', # deepseek-r1:8b, deepseek-v3:671b, deepseek-coder-v2, etc. + r'^command-r', # command-r:35b, command-r-plus:104b, etc. + r'^starcoder', # starcoder2:3b, starcoder2:7b, starcoder2:15b, etc. + r'^codellama', # codellama:7b, codellama:13b, codellama:34b, etc. + r'^codegemma', # codegemma:2b, codegemma:7b, etc. + r'^tinyllama', # tinyllama:1.1b, etc. + r'^wizardcoder', # wizardcoder:7b, wizardcoder:13b, wizardcoder:34b, etc. 
+ priority=10, +) +@dataclasses.dataclass(init=False) +class OllamaLanguageModel(inference.BaseLanguageModel): + """Language model inference class using Ollama based host.""" + + _model: str + _model_url: str + format_type: data.FormatType = data.FormatType.JSON + _constraint: schema.Constraint = dataclasses.field( + default_factory=schema.Constraint, repr=False, compare=False + ) + _extra_kwargs: dict[str, Any] = dataclasses.field( + default_factory=dict, repr=False, compare=False + ) + + def __init__( + self, + model_id: str, + model_url: str = _OLLAMA_DEFAULT_MODEL_URL, + base_url: str | None = None, # Support both model_url and base_url + format_type: data.FormatType | None = None, + structured_output_format: str | None = None, # Deprecated parameter + constraint: schema.Constraint = schema.Constraint(), + **kwargs, + ) -> None: + """Initialize the Ollama language model. + + Args: + model_id: The Ollama model ID to use. + model_url: URL for Ollama server (legacy parameter). + base_url: Alternative parameter name for Ollama server URL. + format_type: Output format (JSON or YAML). Defaults to JSON. + structured_output_format: DEPRECATED - use format_type instead. + constraint: Schema constraints. + **kwargs: Additional parameters. + """ + self._requests = requests + + # Handle deprecated structured_output_format parameter + if structured_output_format is not None: + warnings.warn( + "The 'structured_output_format' parameter is deprecated and will be" + " removed in v2.0.0. Use 'format_type' instead with" + ' data.FormatType.JSON or data.FormatType.YAML.', + DeprecationWarning, + stacklevel=2, + ) + # Only use structured_output_format if format_type wasn't explicitly provided + if format_type is None: + format_type = ( + data.FormatType.JSON + if structured_output_format == 'json' + else data.FormatType.YAML + ) + + # Default to JSON if neither parameter was provided + if format_type is None: + format_type = data.FormatType.JSON + + self._model = model_id + # Support both model_url and base_url parameters + self._model_url = base_url or model_url or _OLLAMA_DEFAULT_MODEL_URL + self.format_type = format_type + self._constraint = constraint + self._extra_kwargs = kwargs or {} + super().__init__(constraint=constraint) + + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[inference.ScoredOutput]]: + """Runs inference on a list of prompts via Ollama's API. + + Args: + batch_prompts: A list of string prompts. + **kwargs: Additional generation params. + + Yields: + Lists of ScoredOutputs. + """ + for prompt in batch_prompts: + try: + response = self._ollama_query( + prompt=prompt, + model=self._model, + structured_output_format='json' + if self.format_type == data.FormatType.JSON + else 'yaml', + model_url=self._model_url, + **kwargs, + ) + # No score for Ollama. 
Default to 1.0 + yield [inference.ScoredOutput(score=1.0, output=response['response'])] + except Exception as e: + raise exceptions.InferenceRuntimeError( + f'Ollama API error: {str(e)}', original=e + ) from e + + def _ollama_query( + self, + prompt: str, + model: str | None = None, + temperature: float = 0.8, + seed: int | None = None, + top_k: int | None = None, + max_output_tokens: int | None = None, + structured_output_format: str | None = None, + system: str = '', + raw: bool = False, + model_url: str | None = None, + timeout: int = 30, + keep_alive: int = 5 * 60, + num_threads: int | None = None, + num_ctx: int = 2048, + ) -> Mapping[str, Any]: + """Sends a prompt to an Ollama model and returns the generated response. + + This function makes an HTTP POST request to the `/api/generate` endpoint of + an Ollama server. It can optionally load the specified model first, generate + a response (with or without streaming), then return a parsed JSON response. + + Args: + prompt: The text prompt to send to the model. + model: The name of the model to use. Defaults to self._model. + temperature: Sampling temperature. Higher values produce more diverse + output. + seed: Seed for reproducible generation. If None, random seed is used. + top_k: The top-K parameter for sampling. + max_output_tokens: Maximum tokens to generate. If None, the model's + default is used. + structured_output_format: If set to "json" or a JSON schema dict, requests + structured outputs from the model. See Ollama documentation for details. + system: A system prompt to override any system-level instructions. + raw: If True, bypasses any internal prompt templating; you provide the + entire raw prompt. + model_url: The base URL for the Ollama server. Defaults to self._model_url. + timeout: Timeout (in seconds) for the HTTP request. + keep_alive: How long (in seconds) the model remains loaded after + generation completes. + num_threads: Number of CPU threads to use. If None, Ollama uses a default + heuristic. + num_ctx: Number of context tokens allowed. If None, uses model's default + or config. + **kwargs: Additional parameters passed through. + + Returns: + A mapping (dictionary-like) containing the server's JSON response. For + non-streaming calls, the `"response"` key typically contains the entire + generated text. + + Raises: + InferenceConfigError: If the server returns a 404 (model not found). + InferenceRuntimeError: For any other HTTP errors, timeouts, or request + exceptions. 
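+
+    Example (illustrative sketch; assumes a local Ollama server is running):
+
+      # Returns a dict whose 'response' key holds the generated text.
+      self._ollama_query(prompt='Summarize this note.', temperature=0.2)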
+ """ + model = model or self._model + model_url = model_url or self._model_url + if structured_output_format is None: + structured_output_format = ( + 'json' if self.format_type == data.FormatType.JSON else 'yaml' + ) + + options: dict[str, Any] = {'keep_alive': keep_alive} + if seed: + options['seed'] = seed + if temperature: + options['temperature'] = temperature + if top_k: + options['top_k'] = top_k + if num_threads: + options['num_thread'] = num_threads + if max_output_tokens: + options['num_predict'] = max_output_tokens + if num_ctx: + options['num_ctx'] = num_ctx + + api_url = model_url + '/api/generate' + + payload = { + 'model': model, + 'prompt': prompt, + 'system': system, + 'stream': False, + 'raw': raw, + 'format': structured_output_format, + 'options': options, + } + + try: + response = self._requests.post( + api_url, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + json=payload, + timeout=timeout, + ) + except self._requests.exceptions.RequestException as e: + if isinstance(e, self._requests.exceptions.ReadTimeout): + msg = ( + f'Ollama Model timed out (timeout={timeout},' + f' num_threads={num_threads})' + ) + raise exceptions.InferenceRuntimeError( + msg, original=e, provider='Ollama' + ) from e + raise exceptions.InferenceRuntimeError( + f'Ollama request failed: {str(e)}', original=e, provider='Ollama' + ) from e + + response.encoding = 'utf-8' + if response.status_code == 200: + return response.json() + if response.status_code == 404: + raise exceptions.InferenceConfigError( + f"Can't find Ollama {model}. Try launching `ollama run {model}`" + ' from command line.' + ) + else: + msg = f'Bad status code from Ollama: {response.status_code}' + raise exceptions.InferenceRuntimeError(msg, provider='Ollama') diff --git a/langextract/providers/openai.py b/langextract/providers/openai.py new file mode 100644 index 00000000..84c35afb --- /dev/null +++ b/langextract/providers/openai.py @@ -0,0 +1,201 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenAI provider for LangExtract.""" +# pylint: disable=cyclic-import,duplicate-code + +from __future__ import annotations + +import concurrent.futures +import dataclasses +from typing import Any, Iterator, Sequence + +from langextract import data +from langextract import exceptions +from langextract import inference +from langextract import schema +from langextract.providers import registry + + +@registry.register( + r'^gpt-4', # gpt-4.1, gpt-4o, gpt-4-turbo, etc. + r'^gpt4\.', # gpt4.1-mini, gpt4.1-nano, etc. 
+ priority=10, +) +@dataclasses.dataclass(init=False) +class OpenAILanguageModel(inference.BaseLanguageModel): + """Language model inference using OpenAI's API with structured output.""" + + model_id: str = 'gpt-4o-mini' + api_key: str | None = None + base_url: str | None = None + organization: str | None = None + format_type: data.FormatType = data.FormatType.JSON + temperature: float = 0.0 + max_workers: int = 10 + _client: Any = dataclasses.field(default=None, repr=False, compare=False) + _extra_kwargs: dict[str, Any] = dataclasses.field( + default_factory=dict, repr=False, compare=False + ) + + def __init__( + self, + model_id: str = 'gpt-4o-mini', + api_key: str | None = None, + base_url: str | None = None, + organization: str | None = None, + format_type: data.FormatType = data.FormatType.JSON, + temperature: float = 0.0, + max_workers: int = 10, + **kwargs, + ) -> None: + """Initialize the OpenAI language model. + + Args: + model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o'). + api_key: API key for OpenAI service. + base_url: Base URL for OpenAI service. + organization: Optional OpenAI organization ID. + format_type: Output format (JSON or YAML). + temperature: Sampling temperature. + max_workers: Maximum number of parallel API calls. + **kwargs: Ignored extra parameters so callers can pass a superset of + arguments shared across back-ends without raising ``TypeError``. + """ + # Lazy import: OpenAI package required + try: + # pylint: disable=import-outside-toplevel + import openai + except ImportError as e: + raise exceptions.InferenceConfigError( + 'OpenAI provider requires openai package. ' + 'Install with: pip install langextract[openai]' + ) from e + + self.model_id = model_id + self.api_key = api_key + self.base_url = base_url + self.organization = organization + self.format_type = format_type + self.temperature = temperature + self.max_workers = max_workers + self._extra_kwargs = kwargs or {} + + if not self.api_key: + raise exceptions.InferenceConfigError('API key not provided for OpenAI.') + + # Initialize the OpenAI client + self._client = openai.OpenAI( + api_key=self.api_key, + base_url=self.base_url, + organization=self.organization, + ) + + super().__init__( + constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) + ) + + def _process_single_prompt( + self, prompt: str, config: dict + ) -> inference.ScoredOutput: + """Process a single prompt and return a ScoredOutput.""" + try: + # Prepare the system message for structured output + system_message = '' + if self.format_type == data.FormatType.JSON: + system_message = ( + 'You are a helpful assistant that responds in JSON format.' + ) + elif self.format_type == data.FormatType.YAML: + system_message = ( + 'You are a helpful assistant that responds in YAML format.' 
+ ) + + response = self._client.chat.completions.create( + model=self.model_id, + messages=[ + {'role': 'system', 'content': system_message}, + {'role': 'user', 'content': prompt}, + ], + temperature=config.get('temperature', self.temperature), + max_tokens=config.get('max_output_tokens'), + top_p=config.get('top_p'), + n=1, + ) + + # Extract the response text using the v1.x response format + output_text = response.choices[0].message.content + + return inference.ScoredOutput(score=1.0, output=output_text) + + except Exception as e: + raise exceptions.InferenceRuntimeError( + f'OpenAI API error: {str(e)}', original=e + ) from e + + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[inference.ScoredOutput]]: + """Runs inference on a list of prompts via OpenAI's API. + + Args: + batch_prompts: A list of string prompts. + **kwargs: Additional generation params (temperature, top_p, etc.) + + Yields: + Lists of ScoredOutputs. + """ + config = { + 'temperature': kwargs.get('temperature', self.temperature), + } + if 'max_output_tokens' in kwargs: + config['max_output_tokens'] = kwargs['max_output_tokens'] + if 'top_p' in kwargs: + config['top_p'] = kwargs['top_p'] + + # Use parallel processing for batches larger than 1 + if len(batch_prompts) > 1 and self.max_workers > 1: + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(self.max_workers, len(batch_prompts)) + ) as executor: + future_to_index = { + executor.submit( + self._process_single_prompt, prompt, config.copy() + ): i + for i, prompt in enumerate(batch_prompts) + } + + results: list[inference.ScoredOutput | None] = [None] * len( + batch_prompts + ) + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + raise exceptions.InferenceRuntimeError( + f'Parallel inference error: {str(e)}', original=e + ) from e + + for result in results: + if result is None: + raise exceptions.InferenceRuntimeError( + 'Failed to process one or more prompts' + ) + yield [result] + else: + # Sequential processing for single prompt or worker + for prompt in batch_prompts: + result = self._process_single_prompt(prompt, config.copy()) + yield [result] diff --git a/langextract/providers/registry.py b/langextract/providers/registry.py new file mode 100644 index 00000000..e4a7a80c --- /dev/null +++ b/langextract/providers/registry.py @@ -0,0 +1,213 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime registry that maps model-ID patterns to provider classes. + +This module provides a lazy registration system for LLM providers, allowing +providers to be registered without importing their dependencies until needed. 
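+
+Example (illustrative sketch; ``MyProvider`` is a hypothetical class):
+
+  from langextract import inference
+  from langextract.providers import registry
+
+  @registry.register(r'^my-model', priority=10)
+  class MyProvider(inference.BaseLanguageModel):
+    ...  # implement infer() to call the backing API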
+""" +# pylint: disable=cyclic-import + +from __future__ import annotations + +import dataclasses +import functools +import importlib +import re +import typing + +from absl import logging + +from langextract import inference + + +@dataclasses.dataclass(frozen=True, slots=True) +class _Entry: + """Registry entry for a provider.""" + + patterns: tuple[re.Pattern[str], ...] + loader: typing.Callable[[], type[inference.BaseLanguageModel]] + priority: int + + +_ENTRIES: list[_Entry] = [] + + +def register_lazy( + *patterns: str | re.Pattern[str], target: str, priority: int = 0 +) -> None: + """Register a provider lazily using string import path. + + Args: + *patterns: One or more regex patterns to match model IDs. + target: Import path in format "module.path:ClassName". + priority: Priority for resolution (higher wins on conflicts). + """ + compiled = tuple(re.compile(p) if isinstance(p, str) else p for p in patterns) + + def _loader() -> type[inference.BaseLanguageModel]: + module_path, class_name = target.rsplit(":", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + _ENTRIES.append(_Entry(patterns=compiled, loader=_loader, priority=priority)) + logging.debug( + "Registered provider with patterns %s at priority %d", + [p.pattern for p in compiled], + priority, + ) + + +def register( + *patterns: str | re.Pattern[str], priority: int = 0 +) -> typing.Callable[ + [type[inference.BaseLanguageModel]], type[inference.BaseLanguageModel] +]: + """Decorator to register a provider class directly. + + Args: + *patterns: One or more regex patterns to match model IDs. + priority: Priority for resolution (higher wins on conflicts). + + Returns: + Decorator function that registers the class. + """ + compiled = tuple(re.compile(p) if isinstance(p, str) else p for p in patterns) + + def _decorator( + cls: type[inference.BaseLanguageModel], + ) -> type[inference.BaseLanguageModel]: + def _loader() -> type[inference.BaseLanguageModel]: + return cls + + _ENTRIES.append( + _Entry(patterns=compiled, loader=_loader, priority=priority) + ) + logging.debug( + "Registered %s with patterns %s at priority %d", + cls.__name__, + [p.pattern for p in compiled], + priority, + ) + return cls + + return _decorator + + +@functools.lru_cache(maxsize=128) +def resolve(model_id: str) -> type[inference.BaseLanguageModel]: + """Resolve a model ID to a provider class. + + Args: + model_id: The model identifier to resolve. + + Returns: + The provider class that handles this model ID. + + Raises: + ValueError: If no provider is registered for the model ID. + """ + # pylint: disable=import-outside-toplevel + from langextract import providers + + providers.load_plugins_once() + + sorted_entries = sorted(_ENTRIES, key=lambda e: e.priority, reverse=True) + + for entry in sorted_entries: + if any(pattern.search(model_id) for pattern in entry.patterns): + return entry.loader() + + raise ValueError( + f"No provider registered for model_id={model_id!r}. Available patterns:" + f" {[str(p.pattern) for e in _ENTRIES for p in e.patterns]}" + ) + + +@functools.lru_cache(maxsize=128) +def resolve_provider(provider_name: str) -> type[inference.BaseLanguageModel]: + """Resolve a provider name to a provider class. + + This allows explicit provider selection by name or class name. + + Args: + provider_name: The provider name (e.g., "gemini", "openai") or + class name (e.g., "GeminiLanguageModel"). + + Returns: + The provider class. + + Raises: + ValueError: If no provider matches the name. 
+ """ + # pylint: disable=import-outside-toplevel + from langextract import providers + + providers.load_plugins_once() + + for entry in _ENTRIES: + for pattern in entry.patterns: + if pattern.pattern == f"^{re.escape(provider_name)}$": + return entry.loader() + + for entry in _ENTRIES: + try: + provider_class = entry.loader() + class_name = provider_class.__name__ + if provider_name.lower() in class_name.lower(): + return provider_class + except (ImportError, AttributeError): + continue + + try: + pattern = re.compile(f"^{provider_name}$", re.IGNORECASE) + for entry in _ENTRIES: + for entry_pattern in entry.patterns: + if pattern.pattern == entry_pattern.pattern: + return entry.loader() + except re.error: + pass + + raise ValueError( + f"No provider found matching: {provider_name!r}. " + "Available providers can be listed with list_providers()" + ) + + +def clear() -> None: + """Clear all registered providers. Mainly for testing.""" + global _ENTRIES # pylint: disable=global-statement + _ENTRIES = [] + resolve.cache_clear() + + +def list_providers() -> list[tuple[tuple[str, ...], int]]: + """List all registered providers with their patterns and priorities. + + Returns: + List of (patterns, priority) tuples for debugging. + """ + return [ + (tuple(p.pattern for p in entry.patterns), entry.priority) + for entry in _ENTRIES + ] + + +def list_entries() -> list[tuple[list[str], int]]: + """List all registered patterns and priorities. Mainly for debugging. + + Returns: + List of (patterns, priority) tuples. + """ + return [([p.pattern for p in e.patterns], e.priority) for e in _ENTRIES] diff --git a/pyproject.toml b/pyproject.toml index 027c31e5..47c8a707 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ requires = ["setuptools>=67.0.0", "wheel"] build-backend = "setuptools.build_meta" + [project] name = "langextract" version = "1.0.5" @@ -35,7 +36,6 @@ dependencies = [ "ml-collections>=0.1.0", "more-itertools>=8.0.0", "numpy>=1.20.0", - "openai>=1.50.0", "pandas>=1.3.0", "pydantic>=1.8.0", "python-dotenv>=0.19.0", @@ -52,6 +52,8 @@ dependencies = [ "Bug Tracker" = "https://github.com/google/langextract/issues" [project.optional-dependencies] +openai = ["openai>=1.50.0"] +all = ["openai>=1.50.0"] dev = [ "pyink~=24.3.0", "isort>=5.13.0", @@ -69,7 +71,7 @@ notebook = [ ] [tool.setuptools] -packages = ["langextract"] +packages = ["langextract", "langextract.providers"] include-package-data = false [tool.setuptools.exclude-package-data] diff --git a/tests/factory_test.py b/tests/factory_test.py new file mode 100644 index 00000000..84bcecab --- /dev/null +++ b/tests/factory_test.py @@ -0,0 +1,315 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the factory module.""" + +import os +from unittest import mock + +from absl.testing import absltest + +from langextract import exceptions +from langextract import factory +from langextract import inference +from langextract.providers import registry + + +class FakeGeminiProvider(inference.BaseLanguageModel): + """Fake Gemini provider for testing.""" + + def __init__(self, model_id, api_key=None, **kwargs): + self.model_id = model_id + self.api_key = api_key + self.kwargs = kwargs + super().__init__() + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="gemini")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + +class FakeOpenAIProvider(inference.BaseLanguageModel): + """Fake OpenAI provider for testing.""" + + def __init__(self, model_id, api_key=None, **kwargs): + if not api_key: + raise ValueError("API key required") + self.model_id = model_id + self.api_key = api_key + self.kwargs = kwargs + super().__init__() + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="openai")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + +class FactoryTest(absltest.TestCase): + + def setUp(self): + super().setUp() + registry.clear() + import langextract.providers as providers_module # pylint: disable=import-outside-toplevel + + providers_module._PLUGINS_LOADED = True + registry.register_lazy( + r"^gemini", target="factory_test:FakeGeminiProvider", priority=100 + ) + registry.register_lazy( + r"^gpt", r"^o1", target="factory_test:FakeOpenAIProvider", priority=100 + ) + + def tearDown(self): + super().tearDown() + registry.clear() + import langextract.providers as providers_module # pylint: disable=import-outside-toplevel + + providers_module._PLUGINS_LOADED = False + + def test_create_model_basic(self): + """Test basic model creation.""" + config = factory.ModelConfig( + model_id="gemini-pro", provider_kwargs={"api_key": "test-key"} + ) + + model = factory.create_model(config) + self.assertIsInstance(model, FakeGeminiProvider) + self.assertEqual(model.model_id, "gemini-pro") + self.assertEqual(model.api_key, "test-key") + + def test_create_model_from_id(self): + """Test convenience function for creating model from ID.""" + model = factory.create_model_from_id("gemini-flash", api_key="test-key") + + self.assertIsInstance(model, FakeGeminiProvider) + self.assertEqual(model.model_id, "gemini-flash") + self.assertEqual(model.api_key, "test-key") + + @mock.patch.dict(os.environ, {"GEMINI_API_KEY": "env-gemini-key"}) + def test_uses_gemini_api_key_from_environment(self): + """Factory should use GEMINI_API_KEY from environment for Gemini models.""" + config = factory.ModelConfig(model_id="gemini-pro") + + model = factory.create_model(config) + self.assertEqual(model.api_key, "env-gemini-key") + + @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "env-openai-key"}) + def test_uses_openai_api_key_from_environment(self): + """Factory should use OPENAI_API_KEY from environment for OpenAI models.""" + config = factory.ModelConfig(model_id="gpt-4") + + model = factory.create_model(config) + self.assertEqual(model.api_key, "env-openai-key") + + @mock.patch.dict( + os.environ, {"LANGEXTRACT_API_KEY": "env-langextract-key"}, clear=True + ) + def test_falls_back_to_langextract_api_key_when_provider_key_missing(self): + """Factory uses LANGEXTRACT_API_KEY when provider-specific key is missing.""" + config = 
factory.ModelConfig(model_id="gemini-pro") + + model = factory.create_model(config) + self.assertEqual(model.api_key, "env-langextract-key") + + @mock.patch.dict( + os.environ, + { + "GEMINI_API_KEY": "gemini-key", + "LANGEXTRACT_API_KEY": "langextract-key", + }, + ) + def test_provider_specific_key_takes_priority_over_langextract_key(self): + """Factory prefers provider-specific API key over LANGEXTRACT_API_KEY.""" + config = factory.ModelConfig(model_id="gemini-pro") + + model = factory.create_model(config) + self.assertEqual(model.api_key, "gemini-key") + + def test_explicit_kwargs_override_env(self): + """Test that explicit kwargs override environment variables.""" + with mock.patch.dict(os.environ, {"GEMINI_API_KEY": "env-key"}): + config = factory.ModelConfig( + model_id="gemini-pro", provider_kwargs={"api_key": "explicit-key"} + ) + + model = factory.create_model(config) + self.assertEqual(model.api_key, "explicit-key") + + @mock.patch.dict(os.environ, {}, clear=True) + def test_wraps_provider_initialization_error_in_inference_config_error(self): + """Factory should wrap provider errors in InferenceConfigError.""" + config = factory.ModelConfig(model_id="gpt-4") + + with self.assertRaises(exceptions.InferenceConfigError) as cm: + factory.create_model(config) + + self.assertIn("Failed to create provider", str(cm.exception)) + self.assertIn("API key required", str(cm.exception)) + + def test_raises_error_when_no_provider_matches_model_id(self): + """Factory should raise ValueError for unregistered model IDs.""" + config = factory.ModelConfig(model_id="unknown-model") + + with self.assertRaises(ValueError) as cm: + factory.create_model(config) + + self.assertIn("No provider registered", str(cm.exception)) + + def test_additional_kwargs_passed_through(self): + """Test that additional kwargs are passed to provider.""" + config = factory.ModelConfig( + model_id="gemini-pro", + provider_kwargs={ + "api_key": "test-key", + "temperature": 0.5, + "max_tokens": 100, + "custom_param": "value", + }, + ) + + model = factory.create_model(config) + self.assertEqual(model.kwargs["temperature"], 0.5) + self.assertEqual(model.kwargs["max_tokens"], 100) + self.assertEqual(model.kwargs["custom_param"], "value") + + @mock.patch.dict(os.environ, {"OLLAMA_BASE_URL": "http://custom:11434"}) + def test_ollama_uses_base_url_from_environment(self): + """Factory should use OLLAMA_BASE_URL from environment for Ollama models.""" + + @registry.register(r"^ollama") + class FakeOllamaProvider(inference.BaseLanguageModel): # pylint: disable=unused-variable + + def __init__(self, model_id, base_url=None, **kwargs): + self.model_id = model_id + self.base_url = base_url + super().__init__() + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="ollama")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + config = factory.ModelConfig(model_id="ollama/llama2") + model = factory.create_model(config) + + self.assertEqual(model.base_url, "http://custom:11434") + + def test_model_config_fields_are_immutable(self): + """ModelConfig fields should not be modifiable after creation.""" + config = factory.ModelConfig( + model_id="gemini-pro", provider_kwargs={"api_key": "test"} + ) + + with self.assertRaises(AttributeError): + config.model_id = "different" + + def test_model_config_allows_dict_contents_modification(self): + """ModelConfig allows modification of dict contents (not deeply frozen).""" + config = factory.ModelConfig( + model_id="gemini-pro", 
provider_kwargs={"api_key": "test"} + ) + + config.provider_kwargs["new_key"] = "value" + + self.assertEqual(config.provider_kwargs["new_key"], "value") + + def test_uses_highest_priority_provider_when_multiple_match(self): + """Factory uses highest priority provider when multiple patterns match.""" + + @registry.register(r"^gemini", priority=90) + class AnotherGeminiProvider(inference.BaseLanguageModel): # pylint: disable=unused-variable + + def __init__(self, model_id=None, **kwargs): + self.model_id = model_id or "default-model" + self.kwargs = kwargs + super().__init__() + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="another")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + config = factory.ModelConfig(model_id="gemini-pro") + model = factory.create_model(config) + + self.assertIsInstance(model, FakeGeminiProvider) # Priority 100 wins + + def test_explicit_provider_overrides_pattern_matching(self): + """Factory should use explicit provider even when pattern doesn't match.""" + + @registry.register(r"^another", priority=90) + class AnotherProvider(inference.BaseLanguageModel): + + def __init__(self, model_id=None, **kwargs): + self.model_id = model_id or "default-model" + self.kwargs = kwargs + super().__init__() + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="another")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + config = factory.ModelConfig( + model_id="gemini-pro", provider="AnotherProvider" + ) + model = factory.create_model(config) + + self.assertIsInstance(model, AnotherProvider) + self.assertEqual(model.model_id, "gemini-pro") + + def test_provider_without_model_id_uses_provider_default(self): + """Factory should use provider's default model_id when none specified.""" + + @registry.register(r"^default-provider$", priority=50) + class DefaultProvider(inference.BaseLanguageModel): + + def __init__(self, model_id="default-model", **kwargs): + self.model_id = model_id + self.kwargs = kwargs + super().__init__() + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="default")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + config = factory.ModelConfig(provider="DefaultProvider") + model = factory.create_model(config) + + self.assertIsInstance(model, DefaultProvider) + self.assertEqual(model.model_id, "default-model") + + def test_raises_error_when_neither_model_id_nor_provider_specified(self): + """Factory raises ValueError when config has neither model_id nor provider.""" + config = factory.ModelConfig() + + with self.assertRaises(ValueError) as cm: + factory.create_model(config) + + self.assertIn( + "Either model_id or provider must be specified", str(cm.exception) + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/inference_test.py b/tests/inference_test.py index 2bb91f13..22f3d8d7 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -23,7 +23,7 @@ class TestOllamaLanguageModel(absltest.TestCase): - @mock.patch.object(inference.OllamaLanguageModel, "_ollama_query") + @mock.patch("langextract.providers.ollama.OllamaLanguageModel._ollama_query") def test_ollama_infer(self, mock_ollama_query): # Actuall full gemma2 response using Ollama. 
@@ -129,6 +129,7 @@ def test_openai_infer_with_parameters( results = list(model.infer(batch_prompts)) # Verify API was called correctly + # Note: The new implementation adds a system message for JSON format mock_client.chat.completions.create.assert_called_once_with( model="gpt-4o-mini", messages=[ diff --git a/tests/init_test.py b/tests/init_test.py index d79a07f4..23c138d5 100644 --- a/tests/init_test.py +++ b/tests/init_test.py @@ -24,21 +24,23 @@ from langextract import prompting from langextract import schema import langextract as lx +from langextract.providers import gemini class InitTest(absltest.TestCase): """Test cases for the main package functions.""" @mock.patch.object(schema.GeminiSchema, "from_examples", autospec=True) - @mock.patch.object(inference.GeminiLanguageModel, "infer", autospec=True) + @mock.patch("langextract.factory.create_model") def test_lang_extract_as_lx_extract( - self, mock_gemini_model_infer, mock_gemini_schema + self, mock_create_model, mock_gemini_schema ): input_text = "Patient takes Aspirin 100mg every morning." - # Mock the language model's response - mock_gemini_model_infer.return_value = [[ + # Create a mock model instance + mock_model = mock.MagicMock() + mock_model.infer.return_value = [[ inference.ScoredOutput( output=textwrap.dedent("""\ ```json @@ -64,6 +66,9 @@ def test_lang_extract_as_lx_extract( ) ]] + # Make factory return our mock model + mock_create_model.return_value = mock_model + mock_gemini_schema.return_value = None # No live endpoint to process schema expected_result = data.AnnotatedDocument( @@ -130,14 +135,8 @@ def test_lang_extract_as_lx_extract( ) mock_gemini_schema.assert_not_called() - mock_gemini_model_infer.assert_called_once_with( - inference.GeminiLanguageModel( - model_id="gemini-2.5-flash", - api_key="some_api_key", - gemini_schema=None, - format_type=data.FormatType.JSON, - temperature=0.5, - ), + mock_create_model.assert_called_once() + mock_model.infer.assert_called_once_with( batch_prompts=[prompt_generator.render(input_text)], ) diff --git a/tests/registry_test.py b/tests/registry_test.py new file mode 100644 index 00000000..147a264c --- /dev/null +++ b/tests/registry_test.py @@ -0,0 +1,197 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the provider registry module.""" + +import re +from unittest import mock + +from absl.testing import absltest + +from langextract import inference +from langextract.providers import registry + + +class FakeProvider(inference.BaseLanguageModel): + """Fake provider for testing.""" + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="test")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + +class AnotherFakeProvider(inference.BaseLanguageModel): + """Another fake provider for testing.""" + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="another")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + +class RegistryTest(absltest.TestCase): + + def setUp(self): + super().setUp() + registry.clear() + + def tearDown(self): + super().tearDown() + registry.clear() + + def test_register_decorator(self): + """Test registering a provider using the decorator.""" + + @registry.register(r"^test-model") + class TestProvider(FakeProvider): + pass + + resolved = registry.resolve("test-model-v1") + self.assertEqual(resolved, TestProvider) + + def test_register_lazy(self): + """Test lazy registration with string target.""" + registry.register_lazy(r"^fake-model", target="registry_test:FakeProvider") + + resolved = registry.resolve("fake-model-v2") + self.assertEqual(resolved, FakeProvider) + + def test_multiple_patterns(self): + """Test registering multiple patterns for one provider.""" + registry.register_lazy( + r"^gemini", r"^palm", target="registry_test:FakeProvider" + ) + + self.assertEqual(registry.resolve("gemini-pro"), FakeProvider) + self.assertEqual(registry.resolve("palm-2"), FakeProvider) + + def test_priority_resolution(self): + """Test that higher priority wins on conflicts.""" + registry.register_lazy( + r"^model", target="registry_test:FakeProvider", priority=0 + ) + registry.register_lazy( + r"^model", target="registry_test:AnotherFakeProvider", priority=10 + ) + + resolved = registry.resolve("model-v1") + self.assertEqual(resolved, AnotherFakeProvider) + + def test_no_provider_registered(self): + """Test error when no provider matches.""" + with self.assertRaisesRegex( + ValueError, "No provider registered for model_id='unknown-model'" + ): + registry.resolve("unknown-model") + + def test_caching(self): + """Test that resolve results are cached.""" + registry.register_lazy(r"^cached", target="registry_test:FakeProvider") + + # First call + result1 = registry.resolve("cached-model") + # Second call should return cached result + result2 = registry.resolve("cached-model") + + self.assertIs(result1, result2) + + def test_clear_registry(self): + """Test clearing the registry.""" + registry.register_lazy(r"^temp", target="registry_test:FakeProvider") + + # Should resolve before clear + resolved = registry.resolve("temp-model") + self.assertEqual(resolved, FakeProvider) + + # Clear registry + registry.clear() + + # Should fail after clear + with self.assertRaises(ValueError): + registry.resolve("temp-model") + + def test_list_entries(self): + """Test listing registered entries.""" + registry.register_lazy(r"^test1", target="fake:Target1", priority=5) + registry.register_lazy( + r"^test2", r"^test3", target="fake:Target2", priority=10 + ) + + entries = registry.list_entries() + self.assertEqual(len(entries), 2) + + # Check first entry + patterns1, priority1 = entries[0] + self.assertEqual(patterns1, ["^test1"]) + 
self.assertEqual(priority1, 5) + + # Check second entry + patterns2, priority2 = entries[1] + self.assertEqual(set(patterns2), {"^test2", "^test3"}) + self.assertEqual(priority2, 10) + + def test_lazy_loading_defers_import(self): + """Test that lazy registration doesn't import until resolve.""" + # Register with a module that would fail if imported + registry.register_lazy(r"^lazy", target="non.existent.module:Provider") + + # Registration should succeed without importing + entries = registry.list_entries() + self.assertTrue(any("^lazy" in patterns for patterns, _ in entries)) + + # Only on resolve should it try to import and fail + with self.assertRaises(ModuleNotFoundError): + registry.resolve("lazy-model") + + def test_regex_pattern_objects(self): + """Test using pre-compiled regex patterns.""" + pattern = re.compile(r"^custom-\d+") + + @registry.register(pattern) + class CustomProvider(FakeProvider): + pass + + self.assertEqual(registry.resolve("custom-123"), CustomProvider) + + # Should not match without digits + with self.assertRaises(ValueError): + registry.resolve("custom-abc") + + def test_resolve_provider_by_name(self): + """Test resolving provider by exact name.""" + + @registry.register(r"^test-model", r"^TestProvider$") + class TestProvider(FakeProvider): + pass + + # Resolve by exact class name pattern + provider = registry.resolve_provider("TestProvider") + self.assertEqual(provider, TestProvider) + + # Resolve by partial name match + provider = registry.resolve_provider("test") + self.assertEqual(provider, TestProvider) + + def test_resolve_provider_not_found(self): + """Test resolve_provider raises for unknown provider.""" + with self.assertRaises(ValueError) as cm: + registry.resolve_provider("UnknownProvider") + self.assertIn("No provider found matching", str(cm.exception)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_live_api.py b/tests/test_live_api.py index 8d9801eb..2752e204 100644 --- a/tests/test_live_api.py +++ b/tests/test_live_api.py @@ -29,7 +29,6 @@ import pytest import langextract as lx -from langextract.inference import OpenAILanguageModel load_dotenv() @@ -387,6 +386,37 @@ def test_multilingual_medication_extraction(self): ), "No medication entities found in Japanese text" assert_valid_char_intervals(self, result) + @skip_if_no_gemini + @live_api + @retry_on_transient_errors(max_retries=2) + def test_explicit_provider_gemini(self): + """Test using explicit provider with Gemini.""" + # Test using provider class name + config = lx.factory.ModelConfig( + model_id="gemini-2.5-flash", + provider="GeminiLanguageModel", + provider_kwargs={ + "api_key": GEMINI_API_KEY, + "temperature": 0.0, + }, + ) + + model = lx.factory.create_model(config) + self.assertEqual(model.__class__.__name__, "GeminiLanguageModel") + self.assertEqual(model.model_id, "gemini-2.5-flash") + + # Test using partial name match + config2 = lx.factory.ModelConfig( + model_id="gemini-2.5-flash", + provider="gemini", # Should match GeminiLanguageModel + provider_kwargs={ + "api_key": GEMINI_API_KEY, + }, + ) + + model2 = lx.factory.create_model(config2) + self.assertEqual(model2.__class__.__name__, "GeminiLanguageModel") + @skip_if_no_gemini @live_api @retry_on_transient_errors(max_retries=2) @@ -469,7 +499,6 @@ def test_medication_extraction(self): text_or_documents=input_text, prompt_description=prompt, examples=examples, - language_model_type=OpenAILanguageModel, model_id=DEFAULT_OPENAI_MODEL, api_key=OPENAI_API_KEY, fence_output=True, @@ -519,6 +548,41 @@ def 
test_medication_extraction(self): f"No PO/oral route found in: {route_texts}", ) + @skip_if_no_openai + @live_api + @retry_on_transient_errors(max_retries=2) + def test_explicit_provider_selection(self): + """Test using explicit provider parameter for disambiguation.""" + # Test with explicit model_id and provider + config = lx.factory.ModelConfig( + model_id=DEFAULT_OPENAI_MODEL, + provider="OpenAILanguageModel", # Explicit provider selection + provider_kwargs={ + "api_key": OPENAI_API_KEY, + "fence_output": True, + "temperature": 0.0, + }, + ) + + model = lx.factory.create_model(config) + + # Verify we got the right provider + self.assertEqual(model.__class__.__name__, "OpenAILanguageModel") + self.assertEqual(model.model_id, DEFAULT_OPENAI_MODEL) + + # Also test using provider without model_id (uses default) + config_default = lx.factory.ModelConfig( + provider="OpenAILanguageModel", + provider_kwargs={ + "api_key": OPENAI_API_KEY, + }, + ) + + model_default = lx.factory.create_model(config_default) + self.assertEqual(model_default.__class__.__name__, "OpenAILanguageModel") + # Should use the default model_id from the provider + self.assertEqual(model_default.model_id, "gpt-4o-mini") + @skip_if_no_openai @live_api @retry_on_transient_errors(max_retries=2) @@ -544,7 +608,6 @@ def test_medication_relationship_extraction(self): text_or_documents=input_text, prompt_description=prompt, examples=examples, - language_model_type=OpenAILanguageModel, model_id=DEFAULT_OPENAI_MODEL, api_key=OPENAI_API_KEY, fence_output=True, diff --git a/tests/test_ollama_integration.py b/tests/test_ollama_integration.py index 5ab4397d..b087c303 100644 --- a/tests/test_ollama_integration.py +++ b/tests/test_ollama_integration.py @@ -60,7 +60,6 @@ def test_ollama_extraction(): text_or_documents=input_text, prompt_description=prompt, examples=examples, - language_model_type=lx.inference.OllamaLanguageModel, model_id=model_id, model_url="http://localhost:11434", temperature=0.3, diff --git a/tox.ini b/tox.ini index 7abd98a0..70cdcf12 100644 --- a/tox.ini +++ b/tox.ini @@ -20,7 +20,7 @@ skip_missing_interpreters = True setenv = PYTHONWARNINGS = ignore deps = - .[dev,test] + .[openai,dev,test] commands = pytest -ra -m "not live_api" @@ -51,12 +51,14 @@ passenv = GEMINI_API_KEY LANGEXTRACT_API_KEY OPENAI_API_KEY -deps = {[testenv]deps} +deps = .[all,dev,test] commands = pytest tests/test_live_api.py -v -m live_api --maxfail=1 [testenv:ollama-integration] basepython = python3.11 -deps = {[testenv]deps} +deps = + .[openai,dev,test] + requests>=2.25.0 commands = pytest tests/test_ollama_integration.py -v --tb=short From c8aa788123dbb98e2adb27e2916978e6cda4b26b Mon Sep 17 00:00:00 2001 From: "goelak@google.com" Date: Fri, 8 Aug 2025 08:13:16 -0400 Subject: [PATCH 04/17] Update provider documentation --- langextract/providers/README.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/langextract/providers/README.md b/langextract/providers/README.md index 1395fb38..66e6899d 100644 --- a/langextract/providers/README.md +++ b/langextract/providers/README.md @@ -349,14 +349,6 @@ The factory automatically resolves API keys from environment: 4. **Explicit Control**: Users can force specific providers when needed 5. **Pattern Priority**: All patterns have equal priority (0) by default -## Migration Path for OpenAI - -Currently, OpenAI is an optional built-in provider. Future plan: -1. Move to external plugin package (`langextract-openai`) -2. Users install via `pip install langextract-openai` -3. 
Usage remains exactly the same -4. Benefits: Cleaner dependencies, better modularity - ## Common Issues ### Provider Not Found From f069d6fd07c8e71714e031e8fa0cce3da69ddaab Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Sat, 9 Aug 2025 01:09:11 -0400 Subject: [PATCH 05/17] Update .gitignore with additional development patterns Add common development files, tools, and temporary file patterns --- .gitignore | 111 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/.gitignore b/.gitignore index fc93e588..43ce0213 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,117 @@ docs/_build/ .idea/ .vscode/ *.swp +*.swo +*~ +.*.swp +.*.swo # OS-specific .DS_Store +Thumbs.db +ehthumbs.db +Desktop.ini +$RECYCLE.BIN/ +*.cab +*.msi +*.msm +*.msp +*.lnk + +# Development tools & environments +.python-version +.pytype/ +.mypy_cache/ +.dmypy.json +dmypy.json +.pyre/ +.ruff_cache/ +*.sage.py +.hypothesis/ +.scrapy + +# Jupyter Notebooks +.ipynb_checkpoints +*/.ipynb_checkpoints/* +profile_default/ +ipython_config.py + +# Logs and databases +*.log +*.sql +*.sqlite +*.sqlite3 +db.sqlite3 +db.sqlite3-journal +logs/ +*.pid + +# Security and secrets +*.key +*.pem +*.crt +*.csr +.env.local +.env.production +.env.*.local +secrets/ +credentials/ + +# AI tooling +CLAUDE.md +.claude/settings.local.json +.aider.chat.history.* +.aider.input.history +.gemini/ +GEMINI.md + +# Package managers +pip-log.txt +pip-delete-this-directory.txt +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.pnpm-debug.log* +package-lock.json +yarn.lock +pnpm-lock.yaml + +# Local development +local_settings.py +instance/ +.webassets-cache +.sass-cache/ +*.css.map +*.js.map +.dev/ + +# Temporary files +tmp/ +temp/ +cache/ +*.tmp +*.bak +*.backup +*.orig +.~lock.*# + +# Archives +*.tar +*.tar.gz +*.zip +*.rar +*.7z +*.dmg +*.iso +*.jar + +# Media files +*.mp4 +*.avi +*.mov +*.wmv +*.flv +*.mp3 +*.wav +*.ogg From 0c08fd18892147b82d08d0060c7600c391fcb323 Mon Sep 17 00:00:00 2001 From: "goelak@google.com" Date: Sat, 9 Aug 2025 01:47:03 -0400 Subject: [PATCH 06/17] Update custom provider example to clarify planned model passing feature - Show current approach using factory.create_model() - Add note that direct model passing to extract() is coming soon - Keep planned API as commented code for reference --- examples/custom_provider_plugin/README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/custom_provider_plugin/README.md b/examples/custom_provider_plugin/README.md index 6aaf3795..b8cd299b 100644 --- a/examples/custom_provider_plugin/README.md +++ b/examples/custom_provider_plugin/README.md @@ -61,6 +61,7 @@ Since this example registers the same pattern as the default Gemini provider, yo ```python import langextract as lx +# Create a configured model with explicit provider selection config = lx.factory.ModelConfig( model_id="gemini-2.5-flash", provider="CustomGeminiProvider", @@ -68,11 +69,22 @@ config = lx.factory.ModelConfig( ) model = lx.factory.create_model(config) +# Note: Passing model directly to extract() is coming soon. +# For now, use the model's infer() method directly or pass parameters individually: result = lx.extract( text_or_documents="Your text here", - model=model, - prompt_description="Extract key information" + model_id="gemini-2.5-flash", + api_key="your-api-key", + prompt_description="Extract key information", + examples=[...] 
) + +# Coming soon: Direct model passing +# result = lx.extract( +# text_or_documents="Your text here", +# model=model, # Planned feature +# prompt_description="Extract key information" +# ) ``` ## Creating Your Own Provider From 1a25621c492aa9d5c39bb8d96a0565b44b63cf54 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Sun, 10 Aug 2025 21:00:31 -0400 Subject: [PATCH 07/17] Fix lazy loading for provider pattern registration (#113) Ensure providers are loaded before pattern matching to prevent API key errors when using local models. Optimize to skip loading when provider is explicitly specified. --- langextract/factory.py | 4 ++++ langextract/providers/__init__.py | 37 ++++++++++++++++++++++--------- langextract/providers/registry.py | 2 ++ tests/factory_test.py | 28 +++++++++++++++++++++++ 4 files changed, 60 insertions(+), 11 deletions(-) diff --git a/langextract/factory.py b/langextract/factory.py index 3f3bde65..8e6fc7e8 100644 --- a/langextract/factory.py +++ b/langextract/factory.py @@ -27,6 +27,7 @@ from langextract import exceptions from langextract import inference +from langextract import providers from langextract.providers import registry @@ -107,6 +108,9 @@ def create_model(config: ModelConfig) -> inference.BaseLanguageModel: if config.provider: provider_class = registry.resolve_provider(config.provider) else: + # Load providers before pattern matching + providers.load_builtins_once() + providers.load_plugins_once() provider_class = registry.resolve(config.model_id) except (ModuleNotFoundError, ImportError) as e: raise exceptions.InferenceConfigError( diff --git a/langextract/providers/__init__.py b/langextract/providers/__init__.py index d4b65a81..c0d46003 100644 --- a/langextract/providers/__init__.py +++ b/langextract/providers/__init__.py @@ -26,8 +26,32 @@ from langextract.providers import registry -# Track whether plugins have been loaded +# Track provider loading for lazy initialization _PLUGINS_LOADED = False +_BUILTINS_LOADED = False + + +def load_builtins_once() -> None: + """Load built-in providers to register their patterns. + + Idempotent function that ensures provider patterns are available + for model resolution. 
+ """ + global _BUILTINS_LOADED # pylint: disable=global-statement + if _BUILTINS_LOADED: + return + + # pylint: disable=import-outside-toplevel + from langextract.providers import gemini # noqa: F401 + from langextract.providers import ollama # noqa: F401 + + try: + from langextract.providers import openai # noqa: F401 + except ImportError: + logging.debug("OpenAI provider not available (optional dependency)") + # pylint: enable=import-outside-toplevel + + _BUILTINS_LOADED = True def load_plugins_once() -> None: @@ -73,13 +97,4 @@ def load_plugins_once() -> None: ) -# pylint: disable=wrong-import-position -from langextract.providers import gemini # noqa: F401 -from langextract.providers import ollama # noqa: F401 - -try: - from langextract.providers import openai # noqa: F401 -except ImportError: - pass - -__all__ = ["registry", "load_plugins_once"] +__all__ = ["registry", "load_plugins_once", "load_builtins_once"] diff --git a/langextract/providers/registry.py b/langextract/providers/registry.py index e4a7a80c..1044a25d 100644 --- a/langextract/providers/registry.py +++ b/langextract/providers/registry.py @@ -121,6 +121,7 @@ def resolve(model_id: str) -> type[inference.BaseLanguageModel]: # pylint: disable=import-outside-toplevel from langextract import providers + providers.load_builtins_once() providers.load_plugins_once() sorted_entries = sorted(_ENTRIES, key=lambda e: e.priority, reverse=True) @@ -154,6 +155,7 @@ class name (e.g., "GeminiLanguageModel"). # pylint: disable=import-outside-toplevel from langextract import providers + providers.load_builtins_once() providers.load_plugins_once() for entry in _ENTRIES: diff --git a/tests/factory_test.py b/tests/factory_test.py index 84bcecab..7317d1c9 100644 --- a/tests/factory_test.py +++ b/tests/factory_test.py @@ -210,6 +210,34 @@ def infer_batch(self, prompts, batch_size=32): self.assertEqual(model.base_url, "http://custom:11434") + def test_ollama_models_select_without_api_keys(self): + """Test that Ollama models resolve without API keys or explicit type.""" + + @registry.register( + r"^llama", r"^gemma", r"^mistral", r"^qwen", priority=100 + ) + class FakeOllamaProvider(inference.BaseLanguageModel): + + def __init__(self, model_id, **kwargs): + self.model_id = model_id + super().__init__() + + def infer(self, batch_prompts, **kwargs): + return [[inference.ScoredOutput(score=1.0, output="test")]] + + def infer_batch(self, prompts, batch_size=32): + return self.infer(prompts) + + test_models = ["llama3", "gemma2:2b", "mistral:7b", "qwen3:0.6b"] + + for model_id in test_models: + with self.subTest(model_id=model_id): + with mock.patch.dict(os.environ, {}, clear=True): + config = factory.ModelConfig(model_id=model_id) + model = factory.create_model(config) + self.assertIsInstance(model, FakeOllamaProvider) + self.assertEqual(model.model_id, model_id) + def test_model_config_fields_are_immutable(self): """ModelConfig fields should not be modifiable after creation.""" config = factory.ModelConfig( From a209903a7721c6c8d29879409a41333436f35940 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Sun, 10 Aug 2025 22:08:59 -0400 Subject: [PATCH 08/17] Improve PR validation workflow based on expert review - Add proper permissions (issues: write for comments) - Skip draft PRs to avoid noise - Prevent duplicate comments with hidden marker - Search both title and body for issue links - Support all keyword variants and cross-repo references - Count unique users for reactions, not total count - Include 'write' permission for maintainer override - Add 
concurrency control for rapid edits - Handle cross-repo issues gracefully --- .github/workflows/check-linked-issue.yml | 172 ++++++++++++++--------- 1 file changed, 109 insertions(+), 63 deletions(-) diff --git a/.github/workflows/check-linked-issue.yml b/.github/workflows/check-linked-issue.yml index 916c338a..d231c4f5 100644 --- a/.github/workflows/check-linked-issue.yml +++ b/.github/workflows/check-linked-issue.yml @@ -2,89 +2,135 @@ name: Require linked issue with community support on: pull_request_target: - types: [opened, edited, synchronize, reopened] - workflow_dispatch: + types: [opened, edited, synchronize, reopened, ready_for_review] permissions: contents: read - issues: read + issues: write pull-requests: write +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: enforce: - if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false + if: github.event_name == 'pull_request_target' && !github.event.pull_request.draft runs-on: ubuntu-latest steps: - - name: Verify linked issue - if: github.event_name == 'pull_request_target' - uses: nearform-actions/github-action-check-linked-issues@v1.2.7 + - name: Check linked issue and community support + uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} - comment: true - exclude-branches: main - custom-body: | - No linked issues found. Please add "Fixes #" to your pull request description. + script: | + // Strip code blocks and inline code to avoid false matches + const stripCode = txt => + txt.replace(/```[\s\S]*?```/g, '').replace(/`[^`]*`/g, ''); - Per our [Contributing Guidelines](https://github.com/google/langextract/blob/main/CONTRIBUTING.md#pull-request-guidelines), all PRs must: - - Reference an issue with "Fixes #123" or "Closes #123" - - The linked issue should have 5+ 👍 reactions - - Include discussion demonstrating the importance of the change + // Combine title + body for comprehensive search + const prText = stripCode(`${context.payload.pull_request.title || ''}\n${context.payload.pull_request.body || ''}`); - Use GitHub automation to close the issue when this PR is merged. 
+ // Issue reference pattern: #123, org/repo#123, or full URL (with http/https and optional www) + const issueRef = String.raw`(?:#(?\d+)|(?[\w.-]+)\/(?[\w.-]+)#(?\d+)|https?:\/\/(?:www\.)?github\.com\/(?[\w.-]+)\/(?[\w.-]+)\/issues\/(?\d+))`; - - name: Check community support - if: github.event_name == 'pull_request_target' - uses: actions/github-script@v7 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - // Check if PR author is a maintainer - const prAuthor = context.payload.pull_request.user.login; - const { data: authorPermission } = await github.rest.repos.getCollaboratorPermissionLevel({ - owner: context.repo.owner, - repo: context.repo.repo, - username: prAuthor - }); + // Keywords - supporting common variants + const closingRe = new RegExp(String.raw`\b(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)\b\s*:?\s+${issueRef}`, 'gi'); + const referenceRe = new RegExp(String.raw`\b(?:related\s+to|relates\s+to|refs?|part\s+of|addresses|see(?:\s+also)?|depends\s+on|blocked\s+by|supersedes)\b\s*:?\s+${issueRef}`, 'gi'); - const isMaintainer = ['admin', 'maintain'].includes(authorPermission.permission); + // Gather all matches + const closings = [...prText.matchAll(closingRe)]; + const references = [...prText.matchAll(referenceRe)]; + const first = closings[0] || references[0]; - const body = context.payload.pull_request.body || ''; - const match = body.match(/(?:Fixes|Closes|Resolves)\s+#(\d+)/i); + if (!first) { + // Check for existing comment to avoid duplicates + const MARKER = ''; + const existing = await github.paginate(github.rest.issues.listComments, { + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + per_page: 100, + }); + const alreadyLeft = existing.some(c => c.body && c.body.includes(MARKER)); - if (!match) { - core.setFailed('No linked issue found'); + if (!alreadyLeft) { + const contribUrl = `https://github.com/${context.repo.owner}/${context.repo.repo}/blob/main/CONTRIBUTING.md#pull-request-guidelines`; + const commentBody = [ + 'No linked issues found. Please link an issue in your pull request description or title.', + '', + `Per our [Contributing Guidelines](${contribUrl}), all PRs must:`, + '- Reference an issue with one of:', + ' - **Closing keywords**: `Fixes #123`, `Closes #123`, `Resolves #123` (auto-closes on merge in the same repository)', + ' - **Reference keywords**: `Related to #123`, `Refs #123`, `Part of #123`, `See #123` (links without closing)', + '- The linked issue should have 5+ 👍 reactions from unique users (excluding bots and the PR author)', + '- Include discussion demonstrating the importance of the change', + '', + 'You can also use cross-repo references like `owner/repo#123` or full URLs.', + '', + MARKER + ].join('\n'); + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: commentBody + }); + } + + core.setFailed('No linked issue found. Use "Fixes #123" to close an issue or "Related to #123" to reference it.'); return; } - const issueNumber = Number(match[1]); - const { repository } = await github.graphql(` - query($owner: String!, $repo: String!, $number: Int!) 
{ - repository(owner: $owner, name: $repo) { - issue(number: $number) { - reactionGroups { - content - users { - totalCount - } - } - } - } + // Resolve owner/repo/number, defaulting to the current repo + const groups = first.groups || {}; + const owner = groups.o1 || groups.o2 || context.repo.owner; + const repo = groups.r1 || groups.r2 || context.repo.repo; + const issue_number = Number(groups.num || groups.n1 || groups.n2); + + core.info(`Found linked issue: ${owner}/${repo}#${issue_number}`); + + // Check if PR author is a maintainer (include write permission) + let authorPerm = 'none'; + try { + const { data } = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.payload.pull_request.user.login, + }); + authorPerm = data.permission || 'none'; + } catch (_) { + // User might not have any permissions + } + const isMaintainer = ['admin', 'maintain', 'write'].includes(authorPerm); + + // Count unique users who reacted with 👍 on the linked issue (excluding bots and PR author) + try { + const reactions = await github.paginate(github.rest.reactions.listForIssue, { + owner, + repo, + issue_number, + per_page: 100, + }); + + const prAuthorId = context.payload.pull_request.user.id; + const uniqueThumbs = new Set( + reactions + .filter(r => r.content === '+1' && r.user && r.user.type !== 'Bot' && r.user.id !== prAuthorId) + .map(r => r.user.id) + ).size; + + core.info(`Issue ${owner}/${repo}#${issue_number} has ${uniqueThumbs} unique 👍 reactions`); + + const REQUIRED_THUMBS_UP = 5; + if (uniqueThumbs < REQUIRED_THUMBS_UP && !isMaintainer) { + core.setFailed(`Linked issue ${owner}/${repo}#${issue_number} has only ${uniqueThumbs} 👍 (need ${REQUIRED_THUMBS_UP}). A maintainer can override.`); + return; + } else if (isMaintainer && uniqueThumbs < REQUIRED_THUMBS_UP) { + core.info(`Maintainer ${context.payload.pull_request.user.login} bypassing community support requirement (issue has ${uniqueThumbs} 👍 reactions)`); } - `, { - owner: context.repo.owner, - repo: context.repo.repo, - number: issueNumber - }); - - const reactions = repository.issue.reactionGroups; - const thumbsUp = reactions.find(g => g.content === 'THUMBS_UP')?.users.totalCount || 0; - - core.info(`Issue #${issueNumber} has ${thumbsUp} 👍 reactions`); - - const REQUIRED_THUMBS_UP = 5; - if (thumbsUp < REQUIRED_THUMBS_UP && !isMaintainer) { - core.setFailed(`Issue #${issueNumber} needs at least ${REQUIRED_THUMBS_UP} 👍 reactions (currently has ${thumbsUp})`); - } else if (isMaintainer && thumbsUp < REQUIRED_THUMBS_UP) { - core.info(`Maintainer ${prAuthor} bypassing community support requirement (issue has ${thumbsUp} 👍 reactions)`); + } catch (error) { + core.warning(`Could not check reactions for ${owner}/${repo}#${issue_number}: ${error.message}`); + // Don't fail if we can't access the issue (might be in different repo) } From 898962031946e04cc02c6fbc48ece621476e1e8c Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Sun, 10 Aug 2025 22:45:54 -0400 Subject: [PATCH 09/17] Add tests for provider plugin system (#114) - 6 tests: plugin discovery, loading, idempotency, error handling - Smart CI triggers for integration test on provider changes - New tox environments: plugin-smoke and plugin-integration --- .github/workflows/ci.yaml | 42 +++ langextract/providers/__init__.py | 33 ++- pyproject.toml | 2 + tests/provider_plugin_test.py | 408 ++++++++++++++++++++++++++++++ tox.ini | 19 +- 5 files changed, 494 insertions(+), 10 deletions(-) create mode 100644 
tests/provider_plugin_test.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dc166fa9..e3bacb8a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -133,6 +133,48 @@ jobs: fi tox -e live-api + plugin-integration-test: + needs: test + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + + steps: + - uses: actions/checkout@v4 + + - name: Detect provider-related changes + id: provider-changes + uses: tj-actions/changed-files@v46 + with: + files: | + langextract/providers/** + langextract/factory.py + langextract/inference.py + tests/provider_plugin_test.py + pyproject.toml + .github/workflows/ci.yaml + + - name: Skip if no provider changes + if: steps.provider-changes.outputs.any_changed == 'false' + run: | + echo "No provider-related changes detected – skipping plugin integration test." + exit 0 + + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox + + - name: Run plugin smoke test + run: tox -e plugin-smoke + + - name: Run plugin integration test + run: tox -e plugin-integration + ollama-integration-test: needs: test runs-on: ubuntu-latest diff --git a/langextract/providers/__init__.py b/langextract/providers/__init__.py index c0d46003..85631f48 100644 --- a/langextract/providers/__init__.py +++ b/langextract/providers/__init__.py @@ -69,27 +69,42 @@ def load_plugins_once() -> None: _PLUGINS_LOADED = True return - _PLUGINS_LOADED = True - try: entry_points_group = metadata.entry_points(group="langextract.providers") except Exception as exc: logging.debug("No third-party provider entry points found: %s", exc) return + # Set flag after successful entry point query to avoid disabling discovery + # on transient failures during enumeration. + _PLUGINS_LOADED = True + for entry_point in entry_points_group: try: provider = entry_point.load() - + # Validate provider subclasses but don't auto-register - plugins must + # use their own @register decorators to control patterns. 
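+      # A plugin typically registers itself at import time, e.g.:
+      #   @registry.register(r"^my-model")
+      #   class MyProvider(inference.BaseLanguageModel): ...
+      # so loading the entry point is enough to expose its patterns.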
if isinstance(provider, type): - registry.register(entry_point.name)(provider) - logging.info( - "Registered third-party provider from entry point: %s", - entry_point.name, - ) + # pylint: disable=import-outside-toplevel + # Late import to avoid circular dependency + from langextract import inference + + if issubclass(provider, inference.BaseLanguageModel): + logging.info( + "Loaded third-party provider class from entry point: %s", + entry_point.name, + ) + else: + logging.warning( + "Entry point %s returned non-provider class %r; ignoring", + entry_point.name, + provider, + ) else: + # Module import triggers decorator execution logging.debug( - "Loaded provider module from entry point: %s", entry_point.name + "Loaded provider module/object from entry point: %s", + entry_point.name, ) except Exception as exc: logging.warning( diff --git a/pyproject.toml b/pyproject.toml index 47c8a707..3831c844 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,8 @@ python_functions = "test_*" addopts = "-ra" markers = [ "live_api: marks tests as requiring live API access", + "requires_pip: marks tests that perform pip install/uninstall operations", + "integration: marks integration tests that test multiple components together", ] [tool.pyink] diff --git a/tests/provider_plugin_test.py b/tests/provider_plugin_test.py new file mode 100644 index 00000000..db6bed9c --- /dev/null +++ b/tests/provider_plugin_test.py @@ -0,0 +1,408 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for provider plugin system.""" + +from importlib import metadata +import os +from pathlib import Path +import subprocess +import sys +import tempfile +import textwrap +import types +from unittest import mock +import uuid + +from absl.testing import absltest +import pytest + +import langextract as lx + + +class PluginSmokeTest(absltest.TestCase): + """Basic smoke tests for plugin loading functionality.""" + + def setUp(self): + super().setUp() + lx.providers.registry.clear() + lx.providers._PLUGINS_LOADED = False + self.addCleanup(lx.providers.registry.clear) + self.addCleanup(setattr, lx.providers, "_PLUGINS_LOADED", False) + + def test_plugin_discovery_and_usage(self): + """Test plugin discovery via entry points. + + Entry points can return a class or module. Registration happens via + the @register decorator in both cases. 
+ """ + + def _ep_load(): + @lx.providers.registry.register(r"^plugin-model") + class PluginProvider(lx.inference.BaseLanguageModel): + + def __init__(self, model_id=None, **kwargs): + super().__init__() + self.model_id = model_id + + def infer(self, batch_prompts, **kwargs): + return [[lx.inference.ScoredOutput(score=1.0, output="ok")]] + + return PluginProvider + + ep = types.SimpleNamespace( + name="plugin_provider", + group="langextract.providers", + value="my_pkg:PluginProvider", + load=_ep_load, + ) + + with mock.patch.object( + metadata, + "entry_points", + side_effect=lambda **kw: [ep] + if kw.get("group") == "langextract.providers" + else [], + ): + lx.providers.load_plugins_once() + + resolved_cls = lx.providers.registry.resolve("plugin-model-123") + self.assertEqual( + resolved_cls.__name__, + "PluginProvider", + "Provider should be resolvable after plugin load", + ) + + cfg = lx.factory.ModelConfig(model_id="plugin-model-123") + model = lx.factory.create_model(cfg) + + out = model.infer(["hi"])[0][0].output + self.assertEqual(out, "ok", "Provider should return expected output") + + def test_plugin_disabled_by_env_var(self): + """Test that LANGEXTRACT_DISABLE_PLUGINS=1 prevents plugin loading.""" + + with mock.patch.dict("os.environ", {"LANGEXTRACT_DISABLE_PLUGINS": "1"}): + with mock.patch.object(metadata, "entry_points") as mock_ep: + lx.providers.load_plugins_once() + mock_ep.assert_not_called() + + def test_handles_import_errors_gracefully(self): + """Test that import errors during plugin loading don't crash.""" + + def _bad_load(): + raise ImportError("Plugin not found") + + bad_ep = types.SimpleNamespace( + name="bad_plugin", + group="langextract.providers", + value="bad_pkg:BadProvider", + load=_bad_load, + ) + + with mock.patch.object(metadata, "entry_points", return_value=[bad_ep]): + lx.providers.load_plugins_once() + + providers = lx.providers.registry.list_providers() + self.assertIsInstance( + providers, + list, + "Registry should remain functional after import error", + ) + self.assertEqual( + len(providers), + 0, + "Broken EP should not partially register", + ) + + def test_load_plugins_once_is_idempotent(self): + """Test that load_plugins_once only discovers once.""" + + def _ep_load(): + @lx.providers.registry.register(r"^plugin-model") + class Plugin(lx.inference.BaseLanguageModel): + + def infer(self, *a, **k): + return [[lx.inference.ScoredOutput(score=1.0, output="ok")]] + + return Plugin + + ep = types.SimpleNamespace( + name="plugin_provider", + group="langextract.providers", + value="pkg:Plugin", + load=_ep_load, + ) + + with mock.patch.object( + metadata, + "entry_points", + side_effect=lambda **kw: [ep] + if kw.get("group") == "langextract.providers" + else [], + ) as m: + lx.providers.load_plugins_once() + lx.providers.load_plugins_once() # should be a no-op + self.assertEqual(m.call_count, 1, "Discovery should happen only once") + + def test_non_subclass_entry_point_does_not_crash(self): + """Test that non-BaseLanguageModel classes don't crash the system.""" + + class NotAProvider: # pylint: disable=too-few-public-methods + """Dummy class to test non-provider handling.""" + + bad_ep = types.SimpleNamespace( + name="bad", + group="langextract.providers", + value="bad:NotAProvider", + load=lambda: NotAProvider, + ) + + with mock.patch.object( + metadata, + "entry_points", + side_effect=lambda **kw: [bad_ep] + if kw.get("group") == "langextract.providers" + else [], + ): + lx.providers.load_plugins_once() + # The system should remain functional even if a 
bad provider is loaded + # Trying to use it would fail, but discovery shouldn't crash + providers = lx.providers.registry.list_providers() + self.assertIsInstance( + providers, + list, + "Registry should remain functional with bad provider", + ) + with self.assertRaisesRegex(ValueError, "No provider registered"): + lx.providers.registry.resolve("bad") + + def test_plugin_priority_override_core_provider(self): + """Plugin with higher priority should override core provider on conflicts.""" + + lx.providers.registry.clear() + lx.providers._PLUGINS_LOADED = False + + def _ep_load(): + @lx.providers.registry.register(r"^gemini", priority=50) + class OverrideGemini(lx.inference.BaseLanguageModel): + + def infer(self, batch_prompts, **kwargs): + return [[lx.inference.ScoredOutput(score=1.0, output="override")]] + + return OverrideGemini + + ep = types.SimpleNamespace( + name="override_gemini", + group="langextract.providers", + value="pkg:OverrideGemini", + load=_ep_load, + ) + + with mock.patch.object( + metadata, + "entry_points", + side_effect=lambda **kw: [ep] + if kw.get("group") == "langextract.providers" + else [], + ): + lx.providers.load_plugins_once() + + # Core gemini registers with priority 10 in providers.gemini + # Our plugin registered with priority 50; it should win. + resolved = lx.providers.registry.resolve("gemini-2.5-flash") + self.assertEqual(resolved.__name__, "OverrideGemini") + + def test_resolve_provider_for_plugin(self): + """resolve_provider should find plugin by class name and name-insensitive.""" + + lx.providers.registry.clear() + lx.providers._PLUGINS_LOADED = False + + def _ep_load(): + @lx.providers.registry.register(r"^plugin-resolve") + class ResolveMePlease(lx.inference.BaseLanguageModel): + + def infer(self, batch_prompts, **kwargs): + return [[lx.inference.ScoredOutput(score=1.0, output="ok")]] + + return ResolveMePlease + + ep = types.SimpleNamespace( + name="resolver_plugin", + group="langextract.providers", + value="pkg:ResolveMePlease", + load=_ep_load, + ) + + with mock.patch.object( + metadata, + "entry_points", + side_effect=lambda **kw: [ep] + if kw.get("group") == "langextract.providers" + else [], + ): + lx.providers.load_plugins_once() + + cls_by_exact = lx.providers.registry.resolve_provider("ResolveMePlease") + self.assertEqual(cls_by_exact.__name__, "ResolveMePlease") + + cls_by_partial = lx.providers.registry.resolve_provider("resolveme") + self.assertEqual(cls_by_partial.__name__, "ResolveMePlease") + + +class PluginE2ETest(absltest.TestCase): + """End-to-end test with actual pip installation. + + This test is expensive and only runs when explicitly requested + via tox -e plugin-e2e or in CI when provider files change. + """ + + @pytest.mark.requires_pip + @pytest.mark.integration + def test_pip_install_discovery_and_cleanup(self): + """Test complete plugin lifecycle: install, discovery, usage, uninstall. + + This test: + 1. Creates a Python package with a provider plugin + 2. Installs it via pip + 3. Verifies the plugin is discovered and usable + 4. 
Uninstalls and verifies cleanup + """ + + with tempfile.TemporaryDirectory() as tmpdir: + pkg_name = f"test_langextract_plugin_{uuid.uuid4().hex[:8]}" + pkg_dir = Path(tmpdir) / pkg_name + pkg_dir.mkdir() + + (pkg_dir / pkg_name).mkdir() + (pkg_dir / pkg_name / "__init__.py").write_text("") + + (pkg_dir / pkg_name / "provider.py").write_text(textwrap.dedent(""" + import langextract as lx + + USED_BY_EXTRACT = False + + @lx.providers.registry.register(r'^test-pip-model', priority=50) + class TestPipProvider(lx.inference.BaseLanguageModel): + def __init__(self, model_id, **kwargs): + super().__init__() + self.model_id = model_id + + def infer(self, batch_prompts, **kwargs): + global USED_BY_EXTRACT + USED_BY_EXTRACT = True + return [[lx.inference.ScoredOutput(score=1.0, output="pip test response")]] + """)) + + (pkg_dir / "pyproject.toml").write_text(textwrap.dedent(f""" + [build-system] + requires = ["setuptools>=61.0"] + build-backend = "setuptools.build_meta" + + [project] + name = "{pkg_name}" + version = "0.0.1" + description = "Test plugin for langextract" + + [project.entry-points."langextract.providers"] + test_provider = "{pkg_name}.provider:TestPipProvider" + """)) + + pip_env = { + **os.environ, + "PIP_NO_INPUT": "1", + "PIP_DISABLE_PIP_VERSION_CHECK": "1", + } + result = subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "-e", + str(pkg_dir), + "--no-deps", + "-q", + ], + check=True, + capture_output=True, + text=True, + env=pip_env, + ) + + self.assertEqual(result.returncode, 0, "pip install failed") + self.assertNotIn( + "ERROR", + result.stderr.upper(), + f"pip install had errors: {result.stderr}", + ) + + try: + test_script = Path(tmpdir) / "test_plugin.py" + test_script.write_text(textwrap.dedent(f""" + import langextract as lx + import sys + + lx.providers.load_plugins_once() + + # Test via factory.create_model + cfg = lx.factory.ModelConfig(model_id="test-pip-model-123") + model = lx.factory.create_model(cfg) + result = model.infer(["test prompt"]) + assert result[0][0].output == "pip test response", f"Got: {{result[0][0].output}}" + + # Verify the plugin is resolvable via the registry + resolved = lx.providers.registry.resolve("test-pip-model-xyz") + assert resolved.__name__ == "TestPipProvider", "Plugin should be resolvable" + + from {pkg_name}.provider import USED_BY_EXTRACT + assert USED_BY_EXTRACT, "Provider infer() was not called" + + print("SUCCESS: Plugin test passed") + """)) + + result = subprocess.run( + [sys.executable, str(test_script)], + capture_output=True, + text=True, + check=False, + ) + + self.assertIn( + "SUCCESS", + result.stdout, + f"Test failed. 
stdout: {result.stdout}, stderr: {result.stderr}", + ) + + finally: + subprocess.run( + [sys.executable, "-m", "pip", "uninstall", "-y", pkg_name], + check=False, + capture_output=True, + env=pip_env, + ) + + lx.providers.registry.clear() + lx.providers._PLUGINS_LOADED = False + lx.providers.load_plugins_once() + + with self.assertRaisesRegex( + ValueError, "No provider registered for model_id='test-pip-model" + ): + lx.providers.registry.resolve("test-pip-model-789") + + +if __name__ == "__main__": + absltest.main() diff --git a/tox.ini b/tox.ini index 70cdcf12..d9523ff2 100644 --- a/tox.ini +++ b/tox.ini @@ -22,7 +22,7 @@ setenv = deps = .[openai,dev,test] commands = - pytest -ra -m "not live_api" + pytest -ra -m "not live_api and not requires_pip" [testenv:format] skip_install = true @@ -62,3 +62,20 @@ deps = requests>=2.25.0 commands = pytest tests/test_ollama_integration.py -v --tb=short + +[testenv:plugin-integration] +basepython = python3.11 +setenv = + PIP_NO_INPUT = 1 + PIP_DISABLE_PIP_VERSION_CHECK = 1 +deps = + .[dev,test] +commands = + pytest tests/provider_plugin_test.py::PluginE2ETest -v -m "requires_pip" + +[testenv:plugin-smoke] +basepython = python3.11 +deps = + .[dev,test] +commands = + pytest tests/provider_plugin_test.py::PluginSmokeTest -v From 1a24cd0056013191e6ac70a3ecc144b45e848695 Mon Sep 17 00:00:00 2001 From: Mariano Iglesias Date: Mon, 11 Aug 2025 16:57:01 -0300 Subject: [PATCH 10/17] Adding model and config parameters to extract() (#119) * Adding model and config parameters to extract() * Adding extract precedence tests * Applying changes from #120 --- langextract/__init__.py | 90 +++++++----- tests/extract_precedence_test.py | 228 +++++++++++++++++++++++++++++++ tests/init_test.py | 1 - 3 files changed, 284 insertions(+), 35 deletions(-) create mode 100644 tests/extract_precedence_test.py diff --git a/langextract/__init__.py b/langextract/__init__.py index c32ce18d..6ffb6d97 100644 --- a/langextract/__init__.py +++ b/langextract/__init__.py @@ -83,6 +83,8 @@ def extract( debug: bool = True, model_url: str | None = None, extraction_passes: int = 1, + config: factory.ModelConfig | None = None, + model: inference.BaseLanguageModel | None = None, ) -> data.AnnotatedDocument | Iterable[data.AnnotatedDocument]: """Extracts structured information from text. @@ -106,6 +108,7 @@ def extract( monitor usage with small test runs to estimate costs. model_id: The model ID to use for extraction. language_model_type: The type of language model to use for inference. + DEPRECATED in favor of `model_id`, `config` and `model`. format_type: The format type for the output (JSON or YAML). max_char_buffer: Max number of characters for inference. temperature: The sampling temperature for generation. Higher values (e.g., @@ -146,6 +149,10 @@ def extract( for overlaps). WARNING: Each additional pass reprocesses tokens, potentially increasing API costs. For example, extraction_passes=3 reprocesses tokens 3x. + config: Model configuration to use for extraction (favored over + `model_id` and `language_model_type`.) + model: Model to use for extraction (favored over `config`, `model_id` and + `language_model_type`.) Returns: An AnnotatedDocument with the extracted information when input is a @@ -187,51 +194,66 @@ def extract( ) prompt_template.examples.extend(examples) - # Generate schema constraints if enabled - model_schema = None - schema_constraint = None - - # TODO: Unify schema generation. 
- if ( - use_schema_constraints - and language_model_type == inference.GeminiLanguageModel - ): - model_schema = schema.GeminiSchema.from_examples(prompt_template.examples) - # Handle backward compatibility for language_model_type parameter if language_model_type != inference.GeminiLanguageModel: warnings.warn( "The 'language_model_type' parameter is deprecated and will be removed" " in a future version. The provider is now automatically selected based" - " on the model_id.", + " on the 'model_id' or by the 'config' and/or 'model' parameters.", DeprecationWarning, stacklevel=2, ) - # Use factory to create the language model - base_lm_kwargs: dict[str, Any] = { - "api_key": api_key, - "gemini_schema": model_schema, - "format_type": format_type, - "temperature": temperature, - "model_url": model_url, - "base_url": model_url, # Support both parameter names for Ollama - "constraint": schema_constraint, - "max_workers": max_workers, - } + if use_schema_constraints and (model or config): + warnings.warn( + "The 'use_schema_constraints' parameter is ignored when 'model' or" + " 'config' is provided. To use schema constraints, include them" + " directly in your config's provider_kwargs (e.g., 'gemini_schema' for" + " Gemini models).", + UserWarning, + stacklevel=2, + ) - # Merge user-provided params which have precedence over defaults. - base_lm_kwargs.update(language_model_params or {}) + if not model and not config: + # Generate schema constraints if enabled + model_schema = None - # Filter out None values - filtered_kwargs = {k: v for k, v in base_lm_kwargs.items() if v is not None} + # TODO: Unify schema generation. + if ( + use_schema_constraints + and language_model_type == inference.GeminiLanguageModel + ): + model_schema = schema.GeminiSchema.from_examples(prompt_template.examples) - # Create model using factory - # Providers are loaded lazily by the registry on first resolve - config = factory.ModelConfig( - model_id=model_id, provider_kwargs=filtered_kwargs - ) - language_model = factory.create_model(config) + # Use factory to create the language model + base_lm_kwargs: dict[str, Any] = { + "api_key": api_key, + "gemini_schema": model_schema, + "format_type": format_type, + "temperature": temperature, + "model_url": model_url, + "base_url": model_url, # Support both parameter names for Ollama + "max_workers": max_workers, + } + + # Merge user-provided params which have precedence over defaults. + base_lm_kwargs.update(language_model_params or {}) + + # Filter out None values + filtered_kwargs = {k: v for k, v in base_lm_kwargs.items() if v is not None} + + # Create model using factory + # Providers are loaded lazily by the registry on first resolve + config = factory.ModelConfig( + model_id=model_id, provider_kwargs=filtered_kwargs + ) + + if not model: + if not config: + raise RuntimeError( + "Internal error: Failed to determine model configuration" + ) + model = factory.create_model(config) resolver_defaults = { "fence_output": fence_output, @@ -244,7 +266,7 @@ def extract( res = resolver.Resolver(**resolver_defaults) annotator = annotation.Annotator( - language_model=language_model, + language_model=model, prompt_template=prompt_template, format_type=format_type, fence_output=fence_output, diff --git a/tests/extract_precedence_test.py b/tests/extract_precedence_test.py new file mode 100644 index 00000000..fc40613c --- /dev/null +++ b/tests/extract_precedence_test.py @@ -0,0 +1,228 @@ +# Copyright 2025 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for parameter precedence in extract().""" + +from unittest import mock + +from absl.testing import absltest + +from langextract import data +from langextract import factory +from langextract import inference +import langextract as lx + + +class ExtractParameterPrecedenceTest(absltest.TestCase): + """Tests ensuring correct precedence among extract() parameters.""" + + def setUp(self): + super().setUp() + self.examples = [ + data.ExampleData( + text="example", + extractions=[ + data.Extraction( + extraction_class="entity", + extraction_text="example", + ) + ], + ) + ] + self.description = "description" + + @mock.patch("langextract.annotation.Annotator") + @mock.patch("langextract.factory.create_model") + def test_model_overrides_all_other_parameters( + self, mock_create_model, mock_annotator_cls + ): + provided_model = mock.MagicMock() + mock_annotator = mock_annotator_cls.return_value + mock_annotator.annotate_text.return_value = "ok" + + config = factory.ModelConfig(model_id="config-id") + + result = lx.extract( + text_or_documents="text", + prompt_description=self.description, + examples=self.examples, + model=provided_model, + config=config, + model_id="ignored-model", + api_key="ignored-key", + language_model_type=inference.OpenAILanguageModel, + use_schema_constraints=False, + ) + + mock_create_model.assert_not_called() + _, kwargs = mock_annotator_cls.call_args + self.assertIs(kwargs["language_model"], provided_model) + self.assertEqual(result, "ok") + + @mock.patch("langextract.annotation.Annotator") + @mock.patch("langextract.factory.create_model") + def test_config_overrides_model_id_and_language_model_type( + self, mock_create_model, mock_annotator_cls + ): + config = factory.ModelConfig( + model_id="config-model", provider_kwargs={"api_key": "config-key"} + ) + mock_model = mock.MagicMock() + mock_create_model.return_value = mock_model + mock_annotator = mock_annotator_cls.return_value + mock_annotator.annotate_text.return_value = "ok" + + with mock.patch("langextract.factory.ModelConfig") as mock_model_config: + result = lx.extract( + text_or_documents="text", + prompt_description=self.description, + examples=self.examples, + config=config, + model_id="other-model", + api_key="other-key", + language_model_type=inference.OpenAILanguageModel, + use_schema_constraints=False, + ) + mock_model_config.assert_not_called() + + mock_create_model.assert_called_once_with(config) + _, kwargs = mock_annotator_cls.call_args + self.assertIs(kwargs["language_model"], mock_model) + self.assertEqual(config.provider_kwargs, {"api_key": "config-key"}) + self.assertEqual(result, "ok") + + @mock.patch("langextract.annotation.Annotator") + @mock.patch("langextract.factory.create_model") + def test_model_id_and_base_kwargs_override_language_model_type( + self, mock_create_model, mock_annotator_cls + ): + mock_model = mock.MagicMock() + mock_create_model.return_value = mock_model + mock_annotator_cls.return_value.annotate_text.return_value = "ok" + 
mock_config = mock.MagicMock() + + with mock.patch( + "langextract.factory.ModelConfig", return_value=mock_config + ) as mock_model_config: + with self.assertWarns(DeprecationWarning): + result = lx.extract( + text_or_documents="text", + prompt_description=self.description, + examples=self.examples, + model_id="model-123", + api_key="api-key", + temperature=0.9, + model_url="http://model", + language_model_type=inference.OpenAILanguageModel, + use_schema_constraints=False, + ) + + mock_model_config.assert_called_once() + _, kwargs = mock_model_config.call_args + self.assertEqual(kwargs["model_id"], "model-123") + provider_kwargs = kwargs["provider_kwargs"] + self.assertEqual(provider_kwargs["api_key"], "api-key") + self.assertEqual(provider_kwargs["temperature"], 0.9) + self.assertEqual(provider_kwargs["model_url"], "http://model") + self.assertEqual(provider_kwargs["base_url"], "http://model") + mock_create_model.assert_called_once_with(mock_config) + self.assertEqual(result, "ok") + + @mock.patch("langextract.annotation.Annotator") + @mock.patch("langextract.factory.create_model") + def test_language_model_type_only_emits_warning_and_works( + self, mock_create_model, mock_annotator_cls + ): + mock_model = mock.MagicMock() + mock_create_model.return_value = mock_model + mock_annotator_cls.return_value.annotate_text.return_value = "ok" + mock_config = mock.MagicMock() + + with mock.patch( + "langextract.factory.ModelConfig", return_value=mock_config + ) as mock_model_config: + with self.assertWarns(DeprecationWarning): + result = lx.extract( + text_or_documents="text", + prompt_description=self.description, + examples=self.examples, + language_model_type=inference.OpenAILanguageModel, + use_schema_constraints=False, + ) + + mock_model_config.assert_called_once() + _, kwargs = mock_model_config.call_args + self.assertEqual(kwargs["model_id"], "gemini-2.5-flash") + mock_create_model.assert_called_once_with(mock_config) + self.assertEqual(result, "ok") + + @mock.patch("langextract.annotation.Annotator") + @mock.patch("langextract.factory.create_model") + def test_use_schema_constraints_warns_with_config( + self, mock_create_model, mock_annotator_cls + ): + """Test that use_schema_constraints emits warning when used with config.""" + config = factory.ModelConfig( + model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test-key"} + ) + + mock_model = mock.MagicMock() + mock_create_model.return_value = mock_model + mock_annotator = mock_annotator_cls.return_value + mock_annotator.annotate_text.return_value = "ok" + + with self.assertWarns(UserWarning) as cm: + result = lx.extract( + text_or_documents="text", + prompt_description=self.description, + examples=self.examples, + config=config, + use_schema_constraints=True, + ) + + self.assertIn("use_schema_constraints", str(cm.warning)) + self.assertIn("ignored", str(cm.warning)) + mock_create_model.assert_called_once() + called_config = mock_create_model.call_args[0][0] + self.assertEqual(called_config.model_id, "gemini-2.5-flash") + self.assertNotIn("gemini_schema", called_config.provider_kwargs) + self.assertEqual(result, "ok") + + @mock.patch("langextract.annotation.Annotator") + @mock.patch("langextract.factory.create_model") + def test_use_schema_constraints_warns_with_model( + self, mock_create_model, mock_annotator_cls + ): + """Test that use_schema_constraints emits warning when used with model.""" + provided_model = mock.MagicMock() + mock_annotator = mock_annotator_cls.return_value + mock_annotator.annotate_text.return_value = "ok" + 
+ with self.assertWarns(UserWarning) as cm: + result = lx.extract( + text_or_documents="text", + prompt_description=self.description, + examples=self.examples, + model=provided_model, + use_schema_constraints=True, + ) + + self.assertIn("use_schema_constraints", str(cm.warning)) + self.assertIn("ignored", str(cm.warning)) + mock_create_model.assert_not_called() + self.assertEqual(result, "ok") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/init_test.py b/tests/init_test.py index 23c138d5..b40f0e1d 100644 --- a/tests/init_test.py +++ b/tests/init_test.py @@ -24,7 +24,6 @@ from langextract import prompting from langextract import schema import langextract as lx -from langextract.providers import gemini class InitTest(absltest.TestCase): From 77b7b95cd82e70dde11d348dbaad90360b09f906 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Wed, 13 Aug 2025 03:39:44 -0400 Subject: [PATCH 11/17] Refactor schema system to support provider plugins (#130) Enable providers to define custom schema implementations via BaseSchema abstraction. Add property-based fence output, FormatModeSchema for JSON/YAML providers, and move GeminiSchema to providers/schemas/. --- README.md | 25 +- examples/custom_provider_plugin/README.md | 48 +- .../langextract_provider_example/provider.py | 35 +- .../langextract_provider_example/schema.py | 161 ++++++ .../test_example_provider.py | 1 - langextract/__init__.py | 140 ++--- langextract/factory.py | 96 +++- langextract/inference.py | 46 +- langextract/providers/README.md | 2 +- langextract/providers/gemini.py | 61 ++- langextract/providers/ollama.py | 28 +- langextract/providers/schemas/__init__.py | 19 + langextract/providers/schemas/gemini.py | 145 ++++++ langextract/resolver.py | 4 + langextract/schema.py | 183 +++---- pyproject.toml | 2 +- tests/extract_precedence_test.py | 26 +- tests/extract_schema_integration_test.py | 216 ++++++++ tests/factory_schema_test.py | 258 ++++++++++ tests/init_test.py | 52 +- tests/progress_test.py | 2 - tests/provider_plugin_test.py | 266 +++++++++- tests/provider_schema_test.py | 479 ++++++++++++++++++ tests/registry_test.py | 2 - tests/schema_test.py | 97 ++++ tests/visualization_test.py | 1 - 26 files changed, 2182 insertions(+), 213 deletions(-) create mode 100644 examples/custom_provider_plugin/langextract_provider_example/schema.py create mode 100644 langextract/providers/schemas/__init__.py create mode 100644 langextract/providers/schemas/gemini.py create mode 100644 tests/extract_schema_integration_test.py create mode 100644 tests/factory_schema_test.py create mode 100644 tests/provider_schema_test.py diff --git a/README.md b/README.md index b373ee00..5c45e521 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ - [Quick Start](#quick-start) - [Installation](#installation) - [API Key Setup for Cloud Models](#api-key-setup-for-cloud-models) +- [Adding Custom Model Providers](#adding-custom-model-providers) - [Using OpenAI Models](#using-openai-models) - [Using Local LLMs with Ollama](#using-local-llms-with-ollama) - [More Examples](#more-examples) @@ -253,6 +254,22 @@ result = lx.extract( ) ``` +## Adding Custom Model Providers + +LangExtract supports custom LLM providers via a lightweight plugin system. You can add support for new models without changing core code. 
+ +- Add new model support independently of the core library +- Distribute your provider as a separate Python package +- Keep custom dependencies isolated +- Override or extend built-in providers via priority-based resolution + +See the detailed guide in [Provider System Documentation](langextract/providers/README.md) to learn how to: + +- Register a provider with `@registry.register(...)` +- Publish an entry point for discovery +- Optionally provide a schema with `get_schema_class()` for structured output +- Integrate with the factory via `create_model(...)` + ## Using OpenAI Models LangExtract supports OpenAI models (requires optional dependency: `pip install langextract[openai]`): @@ -274,7 +291,6 @@ result = lx.extract( Note: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet. ## Using Local LLMs with Ollama - LangExtract supports local inference using Ollama, allowing you to run models without API keys: ```python @@ -326,14 +342,7 @@ with development, testing, and pull requests. You must sign a [Contributor License Agreement](https://cla.developers.google.com/about) before submitting patches. -### Adding Custom Model Providers - -LangExtract supports custom LLM providers through a plugin system. You can add support for new models by creating an external Python package that registers with LangExtract's provider registry. This allows you to: -- Add new model support without modifying the core library -- Distribute your provider independently -- Maintain custom dependencies -For detailed instructions, see the [Provider System Documentation](langextract/providers/README.md). ## Testing diff --git a/examples/custom_provider_plugin/README.md b/examples/custom_provider_plugin/README.md index b8cd299b..9b64f078 100644 --- a/examples/custom_provider_plugin/README.md +++ b/examples/custom_provider_plugin/README.md @@ -12,7 +12,8 @@ custom_provider_plugin/ ├── README.md # This file ├── langextract_provider_example/ # Package directory │ ├── __init__.py # Package initialization -│ └── provider.py # Custom provider implementation +│ ├── provider.py # Custom provider implementation +│ └── schema.py # Custom schema implementation (optional) └── test_example_provider.py # Test script ``` @@ -41,6 +42,51 @@ custom_gemini = "langextract_provider_example:CustomGeminiProvider" This entry point allows LangExtract to automatically discover your provider. 
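+
+Once the package is installed, you can sanity-check that the entry point was discovered by resolving the provider through the registry. This is a minimal sketch using the registry helpers shown in this repo's tests; it assumes the `CustomGeminiProvider` class name from this example:
+
+```python
+import langextract as lx
+
+# Trigger plugin discovery (normally this happens lazily on first model resolution).
+lx.providers.load_plugins_once()
+
+# Resolve the provider class registered by this package's entry point.
+provider_cls = lx.providers.registry.resolve_provider("CustomGeminiProvider")
+print(provider_cls.__name__)  # -> CustomGeminiProvider
+```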
+### Custom Schema Support (`schema.py`) + +Providers can optionally implement custom schemas for structured output: + +**Flow:** Examples → `from_examples()` → `to_provider_config()` → Provider kwargs → Inference + +```python +class CustomProviderSchema(lx.schema.BaseSchema): + @classmethod + def from_examples(cls, examples_data, attribute_suffix="_attributes"): + # Analyze examples to find patterns + # Build schema based on extraction classes and attributes seen + return cls(schema_dict) + + def to_provider_config(self): + # Convert schema to provider kwargs + return { + "response_schema": self._schema_dict, + "enable_structured_output": True + } + + @property + def supports_strict_mode(self): + # True = valid JSON output, no markdown fences needed + return True +``` + +Then in your provider: + +```python +class CustomProvider(lx.inference.BaseLanguageModel): + @classmethod + def get_schema_class(cls): + return CustomProviderSchema # Tell LangExtract about your schema + + def __init__(self, **kwargs): + # Receive schema config in kwargs when use_schema_constraints=True + self.response_schema = kwargs.get('response_schema') + + def infer(self, batch_prompts, **kwargs): + # Use schema during API calls + if self.response_schema: + config['response_schema'] = self.response_schema +``` + ## Installation ```bash diff --git a/examples/custom_provider_plugin/langextract_provider_example/provider.py b/examples/custom_provider_plugin/langextract_provider_example/provider.py index fb7317a8..4e67f6df 100644 --- a/examples/custom_provider_plugin/langextract_provider_example/provider.py +++ b/examples/custom_provider_plugin/langextract_provider_example/provider.py @@ -19,6 +19,8 @@ import dataclasses from typing import Any, Iterator, Sequence +from langextract_provider_example import schema as custom_schema + import langextract as lx @@ -30,9 +32,9 @@ class CustomGeminiProvider(lx.inference.BaseLanguageModel): """Example custom LangExtract provider implementation. This demonstrates how to create a custom provider for LangExtract - that can intercept and handle model requests. This example uses - Gemini as the backend, but you would replace this with your own - API or model implementation. + that can intercept and handle model requests. This example wraps + the actual Gemini API to show how custom schemas integrate, but you + would replace the Gemini calls with your own API or model implementation. Note: Since this registers the same pattern as the default Gemini provider, you must explicitly specify this provider when creating a model: @@ -47,6 +49,8 @@ class CustomGeminiProvider(lx.inference.BaseLanguageModel): model_id: str api_key: str | None temperature: float + response_schema: dict[str, Any] | None = None + enable_structured_output: bool = False _client: Any = dataclasses.field(repr=False, compare=False) def __init__( @@ -77,6 +81,12 @@ def __init__( self.api_key = api_key self.temperature = temperature + # Schema kwargs from CustomProviderSchema.to_provider_config() + self.response_schema = kwargs.get('response_schema') + self.enable_structured_output = kwargs.get( + 'enable_structured_output', False + ) + # Store any additional kwargs for potential use self._extra_kwargs = kwargs @@ -89,6 +99,18 @@ def __init__( super().__init__() + @classmethod + def get_schema_class(cls) -> type[lx.schema.BaseSchema] | None: + """Return our custom schema class. + + This allows LangExtract to use our custom schema implementation + when use_schema_constraints=True is specified. 
+ + Returns: + Our custom schema class that will be used to generate constraints. + """ + return custom_schema.CustomProviderSchema + def infer( self, batch_prompts: Sequence[str], **kwargs: Any ) -> Iterator[Sequence[lx.inference.ScoredOutput]]: @@ -110,6 +132,13 @@ def infer( if key in kwargs: config[key] = kwargs[key] + # Apply schema constraints if configured + if self.response_schema and self.enable_structured_output: + # For Gemini, this ensures the model outputs JSON matching our schema + # Adapt this section based on your actual provider's API requirements + config['response_schema'] = self.response_schema + config['response_mime_type'] = 'application/json' + for prompt in batch_prompts: try: # TODO: Replace this with your own API/model calls diff --git a/examples/custom_provider_plugin/langextract_provider_example/schema.py b/examples/custom_provider_plugin/langextract_provider_example/schema.py new file mode 100644 index 00000000..8b742f46 --- /dev/null +++ b/examples/custom_provider_plugin/langextract_provider_example/schema.py @@ -0,0 +1,161 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example custom schema implementation for provider plugins.""" + +from __future__ import annotations + +from typing import Any, Sequence + +import langextract as lx + + +class CustomProviderSchema(lx.schema.BaseSchema): + """Example custom schema implementation for a provider plugin. + + This demonstrates how plugins can provide their own schema implementations + that integrate with LangExtract's schema system. Custom schemas allow + providers to: + + 1. Generate provider-specific constraints from examples + 2. Control output formatting and validation + 3. Optimize for their specific model capabilities + + This example generates a JSON schema from the examples and passes it to + the Gemini backend (which this example provider wraps) for structured output. + """ + + def __init__(self, schema_dict: dict[str, Any], strict_mode: bool = True): + """Initialize the custom schema. + + Args: + schema_dict: The generated JSON schema dictionary. + strict_mode: Whether the provider guarantees valid output. + """ + self._schema_dict = schema_dict + self._strict_mode = strict_mode + + @classmethod + def from_examples( + cls, + examples_data: Sequence[lx.data.ExampleData], + attribute_suffix: str = "_attributes", + ) -> CustomProviderSchema: + """Generate schema from example data. + + This method analyzes the provided examples to build a schema that + captures the structure of expected extractions. Called automatically + by LangExtract when use_schema_constraints=True. + + Args: + examples_data: Example extractions to learn from. + attribute_suffix: Suffix for attribute fields (unused in this example). + + Returns: + A configured CustomProviderSchema instance. + + Example: + If examples contain extractions with class "condition" and attribute + "severity", the schema will constrain the model to only output those + specific classes and attributes. 
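
      An abbreviated sketch of the schema generated for that case
      (illustrative, shape only; see the implementation below):

        {
          "type": "object",
          "properties": {
            "extractions": {
              "type": "array",
              "items": {
                "type": "object",
                "properties": {
                  "extraction_class": {"type": "string", "enum": ["condition"]},
                  "extraction_text": {"type": "string"},
                  "attributes": {
                    "type": "object",
                    "properties": {"severity": {"type": "string"}},
                  },
                },
                "required": ["extraction_class", "extraction_text"],
              },
            },
          },
          "required": ["extractions"],
        }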
+ """ + extraction_classes = set() + attribute_keys = set() + + for example in examples_data: + for extraction in example.extractions: + extraction_classes.add(extraction.extraction_class) + if extraction.attributes: + attribute_keys.update(extraction.attributes.keys()) + + schema_dict = { + "type": "object", + "properties": { + "extractions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "extraction_class": { + "type": "string", + "enum": ( + list(extraction_classes) + if extraction_classes + else None + ), + }, + "extraction_text": {"type": "string"}, + "attributes": { + "type": "object", + "properties": { + key: {"type": "string"} + for key in attribute_keys + }, + }, + }, + "required": ["extraction_class", "extraction_text"], + }, + }, + }, + "required": ["extractions"], + } + + # Remove enum if no classes found + if not extraction_classes: + del schema_dict["properties"]["extractions"]["items"]["properties"][ + "extraction_class" + ]["enum"] + + return cls(schema_dict, strict_mode=True) + + def to_provider_config(self) -> dict[str, Any]: + """Convert schema to provider-specific configuration. + + This is called after from_examples() and returns kwargs that will be + passed to the provider's __init__ method. The provider can then use + these during inference. + + Returns: + Dictionary of provider kwargs that will be passed to the model. + In this example, we return both the schema and a flag to enable + structured output mode. + + Note: + These kwargs are merged with user-provided kwargs, with user values + taking precedence (caller-wins merge semantics). + """ + return { + "response_schema": self._schema_dict, + "enable_structured_output": True, + "output_format": "json", + } + + @property + def supports_strict_mode(self) -> bool: + """Whether this schema guarantees valid structured output. + + Returns: + True if the provider will emit valid JSON without needing + Markdown fences for extraction. + """ + return self._strict_mode + + @property + def schema_dict(self) -> dict[str, Any]: + """Access the underlying schema dictionary. + + Returns: + The JSON schema dictionary. + """ + return self._schema_dict diff --git a/examples/custom_provider_plugin/test_example_provider.py b/examples/custom_provider_plugin/test_example_provider.py index 13ef8494..319c216f 100644 --- a/examples/custom_provider_plugin/test_example_provider.py +++ b/examples/custom_provider_plugin/test_example_provider.py @@ -33,7 +33,6 @@ def main(): print("Set GEMINI_API_KEY or LANGEXTRACT_API_KEY to test") return - # Create model using explicit provider selection config = lx.factory.ModelConfig( model_id="gemini-2.5-flash", provider="CustomGeminiProvider", diff --git a/langextract/__init__.py b/langextract/__init__.py index 6ffb6d97..9509b81e 100644 --- a/langextract/__init__.py +++ b/langextract/__init__.py @@ -73,7 +73,7 @@ def extract( format_type: data.FormatType = data.FormatType.JSON, max_char_buffer: int = 1000, temperature: float = 0.5, - fence_output: bool = False, + fence_output: bool | None = None, use_schema_constraints: bool = True, batch_length: int = 10, max_workers: int = 10, @@ -107,8 +107,10 @@ def extract( additional token costs. Refer to your API provider's pricing details and monitor usage with small test runs to estimate costs. model_id: The model ID to use for extraction. - language_model_type: The type of language model to use for inference. - DEPRECATED in favor of `model_id`, `config` and `model`. 
+ language_model_type: [DEPRECATED] The type of language model to use for + inference. Warning triggers when value differs from the legacy default + (GeminiLanguageModel). This parameter will be removed in v2.0.0. Use + the model, config, or model_id parameters instead. format_type: The format type for the output (JSON or YAML). max_char_buffer: Max number of characters for inference. temperature: The sampling temperature for generation. Higher values (e.g., @@ -116,9 +118,11 @@ def extract( reducing repetitive outputs. Defaults to 0.5. fence_output: Whether to expect/generate fenced output (```json or ```yaml). When True, the model is prompted to generate fenced output and - the resolver expects it. When False, raw JSON/YAML is expected. If your - model utilizes schema constraints, this can generally be set to False - unless the constraint also accounts for code fence delimiters. + the resolver expects it. When False, raw JSON/YAML is expected. When None, + automatically determined based on provider schema capabilities: if a schema + is applied and supports_strict_mode is True, defaults to False; otherwise + True. If your model utilizes schema constraints, this can generally be set + to False unless the constraint also accounts for code fence delimiters. use_schema_constraints: Whether to generate schema constraints for models. For supported models, this enables structured outputs. Defaults to True. batch_length: Number of text chunks processed per batch. Higher values @@ -149,10 +153,11 @@ def extract( for overlaps). WARNING: Each additional pass reprocesses tokens, potentially increasing API costs. For example, extraction_passes=3 reprocesses tokens 3x. - config: Model configuration to use for extraction (favored over - `model_id` and `language_model_type`.) - model: Model to use for extraction (favored over `config`, `model_id` and - `language_model_type`.) + config: Model configuration to use for extraction. Takes precedence over + model_id, api_key, and language_model_type parameters. When both model + and config are provided, model takes precedence. + model: Pre-configured language model to use for extraction. Takes + precedence over all other parameters including config. Returns: An AnnotatedDocument with the extracted information when input is a @@ -170,19 +175,11 @@ def extract( " one ExampleData object with sample extractions." ) - if use_schema_constraints and fence_output: - warnings.warn( - "When `use_schema_constraints` is True and `fence_output` is True, " - "ensure that your schema constraint includes the code fence " - "delimiters, or set `fence_output` to False.", - UserWarning, - ) - if max_workers is not None and batch_length < max_workers: warnings.warn( - f"batch_length ({batch_length}) is less than max_workers" - f" ({max_workers}). Only {batch_length} workers will be used. For" - " optimal parallelization, set batch_length >= max_workers.", + f"batch_length ({batch_length}) < max_workers ({max_workers}). " + f"Only {batch_length} workers will be used. " + "Set batch_length >= max_workers for optimal parallelization.", UserWarning, ) @@ -194,66 +191,77 @@ def extract( ) prompt_template.examples.extend(examples) - # Handle backward compatibility for language_model_type parameter - if language_model_type != inference.GeminiLanguageModel: - warnings.warn( - "The 'language_model_type' parameter is deprecated and will be removed" - " in a future version. 
The provider is now automatically selected based" - " on the 'model_id' or by the 'config' and/or 'model' parameters.", - DeprecationWarning, - stacklevel=2, - ) + language_model = None + + if model: + language_model = model + if fence_output is not None: + language_model.set_fence_output(fence_output) + if use_schema_constraints: + warnings.warn( + "'use_schema_constraints' is ignored when 'model' is provided. " + "The model should already be configured with schema constraints.", + UserWarning, + stacklevel=2, + ) + elif config: + if use_schema_constraints: + warnings.warn( + "With 'config', schema constraints are still applied via examples. " + "Or pass explicit schema in config.provider_kwargs.", + UserWarning, + stacklevel=2, + ) - if use_schema_constraints and (model or config): - warnings.warn( - "The 'use_schema_constraints' parameter is ignored when 'model' or" - " 'config' is provided. To use schema constraints, include them" - " directly in your config's provider_kwargs (e.g., 'gemini_schema' for" - " Gemini models).", - UserWarning, - stacklevel=2, + language_model = factory.create_model( + config=config, + examples=prompt_template.examples if use_schema_constraints else None, + use_schema_constraints=use_schema_constraints, + fence_output=fence_output, ) + else: + if language_model_type != inference.GeminiLanguageModel: + warnings.warn( + "'language_model_type' is deprecated and will be removed in v2.0.0. " + "Use model, config, or model_id parameters instead.", + DeprecationWarning, + stacklevel=2, + ) - if not model and not config: - # Generate schema constraints if enabled - model_schema = None - - # TODO: Unify schema generation. - if ( - use_schema_constraints - and language_model_type == inference.GeminiLanguageModel - ): - model_schema = schema.GeminiSchema.from_examples(prompt_template.examples) - - # Use factory to create the language model base_lm_kwargs: dict[str, Any] = { "api_key": api_key, - "gemini_schema": model_schema, "format_type": format_type, "temperature": temperature, "model_url": model_url, - "base_url": model_url, # Support both parameter names for Ollama + "base_url": model_url, "max_workers": max_workers, } - # Merge user-provided params which have precedence over defaults. - base_lm_kwargs.update(language_model_params or {}) + # TODO(v2.0.0): Remove gemini_schema parameter + if "gemini_schema" in (language_model_params or {}): + warnings.warn( + "'gemini_schema' is deprecated. Schema constraints are now " + "automatically handled. 
This parameter will be ignored.", + DeprecationWarning, + stacklevel=2, + ) + language_model_params = dict(language_model_params or {}) + language_model_params.pop("gemini_schema", None) - # Filter out None values + base_lm_kwargs.update(language_model_params or {}) filtered_kwargs = {k: v for k, v in base_lm_kwargs.items() if v is not None} - - # Create model using factory - # Providers are loaded lazily by the registry on first resolve config = factory.ModelConfig( model_id=model_id, provider_kwargs=filtered_kwargs ) - if not model: - if not config: - raise RuntimeError( - "Internal error: Failed to determine model configuration" - ) - model = factory.create_model(config) + language_model = factory.create_model( + config=config, + examples=prompt_template.examples if use_schema_constraints else None, + use_schema_constraints=use_schema_constraints, + fence_output=fence_output, + ) + + fence_output = language_model.requires_fence_output resolver_defaults = { "fence_output": fence_output, @@ -266,7 +274,7 @@ def extract( res = resolver.Resolver(**resolver_defaults) annotator = annotation.Annotator( - language_model=model, + language_model=language_model, prompt_template=prompt_template, format_type=format_type, fence_output=fence_output, @@ -281,6 +289,7 @@ def extract( additional_context=additional_context, debug=debug, extraction_passes=extraction_passes, + max_workers=max_workers, ) else: documents = cast(Iterable[data.Document], text_or_documents) @@ -291,4 +300,5 @@ def extract( batch_length=batch_length, debug=debug, extraction_passes=extraction_passes, + max_workers=max_workers, ) diff --git a/langextract/factory.py b/langextract/factory.py index 8e6fc7e8..161eb8e8 100644 --- a/langextract/factory.py +++ b/langextract/factory.py @@ -87,20 +87,42 @@ def _kwargs_with_environment_defaults( return resolved -def create_model(config: ModelConfig) -> inference.BaseLanguageModel: +def create_model( + config: ModelConfig, + examples: typing.Sequence[typing.Any] | None = None, + use_schema_constraints: bool = False, + fence_output: bool | None = None, + return_fence_output: bool = False, +) -> inference.BaseLanguageModel | tuple[inference.BaseLanguageModel, bool]: """Create a language model instance from configuration. Args: config: Model configuration with optional model_id and/or provider. + examples: Optional examples for schema generation (if use_schema_constraints=True). + use_schema_constraints: Whether to apply schema constraints from examples. + fence_output: Explicit fence output preference. If None, computed from schema. + return_fence_output: If True, also return computed fence_output value. Returns: An instantiated language model provider. + If return_fence_output=True: Tuple of (model, model.requires_fence_output). Raises: ValueError: If neither model_id nor provider is specified. ValueError: If no provider is registered for the model_id. InferenceConfigError: If provider instantiation fails. 
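
  Example:
    A minimal sketch (model ID and API key are illustrative):

      config = ModelConfig(
          model_id="gemini-2.5-flash",
          provider_kwargs={"api_key": "..."},
      )
      model = create_model(config)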
""" + if use_schema_constraints or fence_output is not None: + model = _create_model_with_schema( + config=config, + examples=examples, + use_schema_constraints=use_schema_constraints, + fence_output=fence_output, + ) + if return_fence_output: + return model, model.requires_fence_output + return model + if not config.model_id and not config.provider: raise ValueError("Either model_id or provider must be specified") @@ -129,7 +151,10 @@ def create_model(config: ModelConfig) -> inference.BaseLanguageModel: kwargs["model_id"] = model_id try: - return provider_class(**kwargs) + model = provider_class(**kwargs) + if return_fence_output: + return model, model.requires_fence_output + return model except (ValueError, TypeError) as e: raise exceptions.InferenceConfigError( f"Failed to create provider {provider_class.__name__}: {e}" @@ -155,3 +180,70 @@ def create_model_from_id( model_id=model_id, provider=provider, provider_kwargs=provider_kwargs ) return create_model(config) + + +def _create_model_with_schema( + config: ModelConfig, + examples: typing.Sequence[typing.Any] | None = None, + use_schema_constraints: bool = True, + fence_output: bool | None = None, +) -> inference.BaseLanguageModel: + """Internal helper to create a model with optional schema constraints. + + This function creates a language model and optionally configures it with + schema constraints derived from the provided examples. It also computes + appropriate fence defaulting based on the schema's capabilities. + + Args: + config: Model configuration with model_id and/or provider. + examples: Optional sequence of ExampleData for schema generation. + use_schema_constraints: Whether to generate and apply schema constraints. + fence_output: Whether to wrap output in markdown fences. If None, + will be computed based on schema's supports_strict_mode. + + Returns: + A model instance with fence_output configured appropriately. + """ + + if config.provider: + provider_class = registry.resolve_provider(config.provider) + else: + providers.load_builtins_once() + providers.load_plugins_once() + provider_class = registry.resolve(config.model_id) + + schema_instance = None + if use_schema_constraints and examples: + schema_class = provider_class.get_schema_class() + if schema_class is not None: + schema_instance = schema_class.from_examples(examples) + + if schema_instance: + kwargs = schema_instance.to_provider_config() + kwargs.update(config.provider_kwargs) + else: + kwargs = dict(config.provider_kwargs) + + if schema_instance: + schema_instance.sync_with_provider_kwargs(kwargs) + + # Add environment defaults + model_id = config.model_id + kwargs = _kwargs_with_environment_defaults( + model_id or config.provider or "", kwargs + ) + + if model_id: + kwargs["model_id"] = model_id + + try: + model = provider_class(**kwargs) + except (ValueError, TypeError) as e: + raise exceptions.InferenceConfigError( + f"Failed to create provider {provider_class.__name__}: {e}" + ) from e + + model.apply_schema(schema_instance) + model.set_fence_output(fence_output) + + return model diff --git a/langextract/inference.py b/langextract/inference.py index b43977bc..9b9dadbb 100644 --- a/langextract/inference.py +++ b/langextract/inference.py @@ -71,6 +71,50 @@ def __init__(self, constraint: schema.Constraint = schema.Constraint()): constraint. 
""" self._constraint = constraint + self._schema: schema.BaseSchema | None = None + self._fence_output_override: bool | None = None + + @classmethod + def get_schema_class(cls) -> type[schema.BaseSchema] | None: + """Return the schema class this provider supports.""" + return None + + def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None: + """Apply a schema instance to this provider. + + Optional method that providers can override to store the schema instance + for runtime use. The default implementation stores it as _schema. + + Args: + schema_instance: The schema instance to apply, or None to clear. + """ + self._schema = schema_instance + + def set_fence_output(self, fence_output: bool | None) -> None: + """Set explicit fence output preference. + + Args: + fence_output: True to force fences, False to disable, None for auto. + """ + if not hasattr(self, '_fence_output_override'): + self._fence_output_override = None + self._fence_output_override = fence_output + + @property + def requires_fence_output(self) -> bool: + """Whether this model requires fence output for parsing. + + Uses explicit override if set, otherwise computes from schema. + Returns True if no schema or schema doesn't support strict mode. + """ + if ( + hasattr(self, '_fence_output_override') + and self._fence_output_override is not None + ): + return self._fence_output_override + if not hasattr(self, '_schema') or self._schema is None: + return True + return not self._schema.supports_strict_mode @abc.abstractmethod def infer( @@ -245,7 +289,7 @@ def parse_output(self, output: str) -> Any: 'Use langextract.providers.openai.OpenAILanguageModel instead. ' 'Will be removed in v2.0.0.' ) -class OpenAILanguageModel(BaseLanguageModel): # pylint: disable=too-many-instance-attributes +class OpenAILanguageModel(BaseLanguageModel): """Language model inference using OpenAI's API with structured output. DEPRECATED: Use langextract.providers.openai.OpenAILanguageModel instead. diff --git a/langextract/providers/README.md b/langextract/providers/README.md index 66e6899d..7596e7af 100644 --- a/langextract/providers/README.md +++ b/langextract/providers/README.md @@ -199,7 +199,7 @@ from langextract import factory # Specify both model and provider (useful when multiple providers support same model) config = factory.ModelConfig( - model_id="llama3.2:1b", + model_id="gemma2:2b", provider="OllamaLanguageModel", # Explicitly use Ollama provider_kwargs={ "model_url": "http://localhost:11434" diff --git a/langextract/providers/gemini.py b/langextract/providers/gemini.py index 09ce2ec8..c995dc40 100644 --- a/langextract/providers/gemini.py +++ b/langextract/providers/gemini.py @@ -47,6 +47,26 @@ class GeminiLanguageModel(inference.BaseLanguageModel): default_factory=dict, repr=False, compare=False ) + @classmethod + def get_schema_class(cls) -> type[schema.BaseSchema] | None: + """Return the GeminiSchema class for structured output support. + + Returns: + The GeminiSchema class that supports strict schema constraints. + """ + return schema.GeminiSchema + + def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None: + """Apply a schema instance to this provider. + + Args: + schema_instance: The schema instance to apply, or None to clear. 
+ """ + super().apply_schema(schema_instance) + # Keep provider behavior consistent with legacy path + if isinstance(schema_instance, schema.GeminiSchema): + self.gemini_schema = schema_instance + def __init__( self, model_id: str = 'gemini-2.5-flash', @@ -69,8 +89,10 @@ def __init__( max_workers: Maximum number of parallel API calls. fence_output: Whether to wrap output in markdown fences (ignored, Gemini handles this based on schema). - **kwargs: Ignored extra parameters so callers can pass a superset of - arguments shared across back-ends without raising ``TypeError``. + **kwargs: Additional Gemini API parameters. Only allowlisted keys are + forwarded to the API (response_schema, response_mime_type, tools, + safety_settings, stop_sequences, candidate_count, system_instruction). + See https://ai.google.dev/api/generate-content for details. """ try: # pylint: disable=import-outside-toplevel @@ -86,10 +108,19 @@ def __init__( self.format_type = format_type self.temperature = temperature self.max_workers = max_workers - self.fence_output = ( - fence_output # Store but may not use depending on schema - ) - self._extra_kwargs = kwargs or {} + self.fence_output = fence_output + api_config_keys = { + 'response_schema', + 'response_mime_type', + 'tools', + 'safety_settings', + 'stop_sequences', + 'candidate_count', + 'system_instruction', + } + self._extra_kwargs = { + k: v for k, v in (kwargs or {}).items() if k in api_config_keys + } if not self.api_key: raise exceptions.InferenceConfigError('API key not provided for Gemini.') @@ -105,15 +136,17 @@ def _process_single_prompt( ) -> inference.ScoredOutput: """Process a single prompt and return a ScoredOutput.""" try: + if self._extra_kwargs: + config.update(self._extra_kwargs) if self.gemini_schema: - response_schema = self.gemini_schema.schema_dict - mime_type = ( - 'application/json' - if self.format_type == data.FormatType.JSON - else 'application/yaml' - ) - config['response_mime_type'] = mime_type - config['response_schema'] = response_schema + # Gemini structured output only supports JSON + if self.format_type != data.FormatType.JSON: + raise exceptions.InferenceConfigError( + 'Gemini structured output only supports JSON format. ' + 'Set format_type=JSON or use_schema_constraints=False.' + ) + config.setdefault('response_mime_type', 'application/json') + config.setdefault('response_schema', self.gemini_schema.schema_dict) response = self._client.models.generate_content( model=self.model_id, contents=prompt, config=config # type: ignore[arg-type] diff --git a/langextract/providers/ollama.py b/langextract/providers/ollama.py index dddf2d55..f3fd034f 100644 --- a/langextract/providers/ollama.py +++ b/langextract/providers/ollama.py @@ -63,13 +63,22 @@ class OllamaLanguageModel(inference.BaseLanguageModel): default_factory=dict, repr=False, compare=False ) + @classmethod + def get_schema_class(cls) -> type[schema.BaseSchema] | None: + """Return the FormatModeSchema class for JSON output support. + + Returns: + The FormatModeSchema class that enables JSON mode (non-strict). 
+ """ + return schema.FormatModeSchema + def __init__( self, model_id: str, model_url: str = _OLLAMA_DEFAULT_MODEL_URL, - base_url: str | None = None, # Support both model_url and base_url + base_url: str | None = None, # Alias for model_url format_type: data.FormatType | None = None, - structured_output_format: str | None = None, # Deprecated parameter + structured_output_format: str | None = None, # Deprecated constraint: schema.Constraint = schema.Constraint(), **kwargs, ) -> None: @@ -89,9 +98,8 @@ def __init__( # Handle deprecated structured_output_format parameter if structured_output_format is not None: warnings.warn( - "The 'structured_output_format' parameter is deprecated and will be" - " removed in v2.0.0. Use 'format_type' instead with" - ' data.FormatType.JSON or data.FormatType.YAML.', + "'structured_output_format' is deprecated and will be removed in " + "v2.0.0. Use 'format_type' instead.", DeprecationWarning, stacklevel=2, ) @@ -103,6 +111,12 @@ def __init__( else data.FormatType.YAML ) + fmt = kwargs.pop('format', None) + if format_type is None and fmt in ('json', 'yaml'): + format_type = ( + data.FormatType.JSON if fmt == 'json' else data.FormatType.YAML + ) + # Default to JSON if neither parameter was provided if format_type is None: format_type = data.FormatType.JSON @@ -161,6 +175,7 @@ def _ollama_query( keep_alive: int = 5 * 60, num_threads: int | None = None, num_ctx: int = 2048, + **kwargs, # pylint: disable=unused-argument ) -> Mapping[str, Any]: """Sends a prompt to an Ollama model and returns the generated response. @@ -263,8 +278,7 @@ def _ollama_query( return response.json() if response.status_code == 404: raise exceptions.InferenceConfigError( - f"Can't find Ollama {model}. Try launching `ollama run {model}`" - ' from command line.' + f"Can't find Ollama {model}. Try: ollama run {model}" ) else: msg = f'Bad status code from Ollama: {response.status_code}' diff --git a/langextract/providers/schemas/__init__.py b/langextract/providers/schemas/__init__.py new file mode 100644 index 00000000..d36aa499 --- /dev/null +++ b/langextract/providers/schemas/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provider-specific schema implementations.""" + +from langextract.providers.schemas.gemini import GeminiSchema + +__all__ = ["GeminiSchema"] diff --git a/langextract/providers/schemas/gemini.py b/langextract/providers/schemas/gemini.py new file mode 100644 index 00000000..f6cd9e2e --- /dev/null +++ b/langextract/providers/schemas/gemini.py @@ -0,0 +1,145 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemini provider schema implementation.""" + +from __future__ import annotations + +from collections.abc import Sequence +import dataclasses +from typing import Any + +from langextract import data +from langextract import schema + +EXTRACTIONS_KEY = schema.EXTRACTIONS_KEY + + +@dataclasses.dataclass +class GeminiSchema(schema.BaseSchema): + """Schema implementation for Gemini structured output. + + Converts ExampleData objects into an OpenAPI/JSON-schema definition + that Gemini can interpret via 'response_schema'. + """ + + _schema_dict: dict + + @property + def schema_dict(self) -> dict: + """Returns the schema dictionary.""" + return self._schema_dict + + @schema_dict.setter + def schema_dict(self, schema_dict: dict) -> None: + """Sets the schema dictionary.""" + self._schema_dict = schema_dict + + def to_provider_config(self) -> dict[str, Any]: + """Convert schema to Gemini-specific configuration. + + Returns: + Dictionary with response_schema and response_mime_type for Gemini API. + """ + return { + "response_schema": self._schema_dict, + "response_mime_type": "application/json", + } + + @property + def supports_strict_mode(self) -> bool: + """Gemini enforces strict JSON schema constraints. + + Returns: + True, as Gemini can enforce structure strictly via response_schema. + """ + return True + + @classmethod + def from_examples( + cls, + examples_data: Sequence[data.ExampleData], + attribute_suffix: str = "_attributes", + ) -> GeminiSchema: + """Creates a GeminiSchema from example extractions. + + Builds a JSON-based schema with a top-level "extractions" array. Each + element in that array is an object containing the extraction class name + and an accompanying "_attributes" object for its attributes. + + Args: + examples_data: A sequence of ExampleData objects containing extraction + classes and attributes. + attribute_suffix: String appended to each class name to form the + attributes field name (defaults to "_attributes"). + + Returns: + A GeminiSchema with internal dictionary represents the JSON constraint. 
+ """ + # Track attribute types for each category + extraction_categories: dict[str, dict[str, set[type]]] = {} + for example in examples_data: + for extraction in example.extractions: + category = extraction.extraction_class + if category not in extraction_categories: + extraction_categories[category] = {} + + if extraction.attributes: + for attr_name, attr_value in extraction.attributes.items(): + if attr_name not in extraction_categories[category]: + extraction_categories[category][attr_name] = set() + extraction_categories[category][attr_name].add(type(attr_value)) + + extraction_properties: dict[str, dict[str, Any]] = {} + + for category, attrs in extraction_categories.items(): + extraction_properties[category] = {"type": "string"} + + attributes_field = f"{category}{attribute_suffix}" + attr_properties = {} + + # Default property for categories without attributes + if not attrs: + attr_properties["_unused"] = {"type": "string"} + else: + for attr_name, attr_types in attrs.items(): + # List attributes become arrays + if list in attr_types: + attr_properties[attr_name] = { + "type": "array", + "items": {"type": "string"}, + } + else: + attr_properties[attr_name] = {"type": "string"} + + extraction_properties[attributes_field] = { + "type": "object", + "properties": attr_properties, + "nullable": True, + } + + extraction_schema = { + "type": "object", + "properties": extraction_properties, + } + + schema_dict = { + "type": "object", + "properties": { + EXTRACTIONS_KEY: {"type": "array", "items": extraction_schema} + }, + "required": [EXTRACTIONS_KEY], + } + + return cls(_schema_dict=schema_dict) diff --git a/langextract/resolver.py b/langextract/resolver.py index c6496b82..e6ee4533 100644 --- a/langextract/resolver.py +++ b/langextract/resolver.py @@ -117,6 +117,7 @@ def align( enable_fuzzy_alignment: bool = True, fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, accept_match_lesser: bool = True, + **kwargs, ) -> Iterator[data.Extraction]: """Aligns extractions with source text, setting token/char intervals and alignment status. @@ -143,6 +144,7 @@ def align( (0-1). accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER status). + **kwargs: Additional keyword arguments for provider-specific alignment. Yields: Aligned extractions with updated token intervals and alignment status. @@ -245,6 +247,7 @@ def align( enable_fuzzy_alignment: bool = True, fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, accept_match_lesser: bool = True, + **kwargs, ) -> Iterator[data.Extraction]: """Aligns annotated extractions with source text. @@ -264,6 +267,7 @@ def align( alignment. accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER status). + **kwargs: Additional parameters. Yields: Iterator on aligned extractions. diff --git a/langextract/schema.py b/langextract/schema.py index dd553bdc..08764d67 100644 --- a/langextract/schema.py +++ b/langextract/schema.py @@ -24,6 +24,8 @@ from langextract import data +EXTRACTIONS_KEY = "extractions" # Shared key for extraction arrays in JSON/YAML + class ConstraintType(enum.Enum): """Enumeration of constraint types.""" @@ -31,7 +33,7 @@ class ConstraintType(enum.Enum): NONE = "none" -# TODO: Remove and decouple Constraint and ConstraintType from Schema class. +# TODO(v2.0.0): Remove and decouple Constraint and ConstraintType from Schema class. @dataclasses.dataclass class Constraint: """Represents a constraint for model output decoding. 
@@ -43,9 +45,6 @@ class Constraint: constraint_type: ConstraintType = ConstraintType.NONE -EXTRACTIONS_KEY = "extractions" - - class BaseSchema(abc.ABC): """Abstract base class for generating structured constraints from examples.""" @@ -58,101 +57,115 @@ def from_examples( ) -> BaseSchema: """Factory method to build a schema instance from example data.""" + @abc.abstractmethod + def to_provider_config(self) -> dict[str, Any]: + """Convert schema to provider-specific configuration. -@dataclasses.dataclass -class GeminiSchema(BaseSchema): - """Schema implementation for Gemini structured output. + Returns: + Dictionary of provider kwargs (e.g., response_schema for Gemini). + Should be a pure data mapping with no side effects. + """ - Converts ExampleData objects into an OpenAPI/JSON-schema definition - that Gemini can interpret via 'response_schema'. - """ + @property + @abc.abstractmethod + def supports_strict_mode(self) -> bool: + """Whether the provider emits valid output without needing Markdown fences. - _schema_dict: dict + Returns: + True when the provider will emit syntactically valid JSON (or other + machine-parseable format) without needing Markdown fences. This says + nothing about attribute-level schema enforcement. False otherwise. + """ - @property - def schema_dict(self) -> dict: - """Returns the schema dictionary.""" - return self._schema_dict + def sync_with_provider_kwargs(self, kwargs: dict[str, Any]) -> None: + """Hook to update schema state based on provider kwargs. + + This allows schemas to adjust their behavior based on caller overrides. + For example, FormatModeSchema uses this to sync its format when the caller + overrides it, ensuring supports_strict_mode stays accurate. + + Default implementation does nothing. Override if your schema needs to + respond to provider kwargs. + + Args: + kwargs: The effective provider kwargs after merging. + """ + + +class FormatModeSchema(BaseSchema): + """Generic schema for providers that support format modes (JSON/YAML). + + This schema doesn't enforce structure, only output format. Useful for + providers that can guarantee syntactically valid JSON or YAML but don't + support field-level constraints. + """ + + def __init__(self, format_mode: str = "json"): + """Initialize with a format mode. - @schema_dict.setter - def schema_dict(self, schema_dict: dict) -> None: - """Sets the schema dictionary.""" - self._schema_dict = schema_dict + Args: + format_mode: The output format ("json", "yaml", etc.). + """ + self._format = format_mode @classmethod def from_examples( cls, examples_data: Sequence[data.ExampleData], attribute_suffix: str = "_attributes", - ) -> GeminiSchema: - """Creates a GeminiSchema from example extractions. + ) -> FormatModeSchema: + """Create a FormatModeSchema instance. - Builds a JSON-based schema with a top-level "extractions" array. Each - element in that array is an object containing the extraction class name - and an accompanying "_attributes" object for its attributes. + Since format mode doesn't use examples for constraints, this + simply returns a JSON-mode instance. Args: - examples_data: A sequence of ExampleData objects containing extraction - classes and attributes. - attribute_suffix: String appended to each class name to form the - attributes field name (defaults to "_attributes"). + examples_data: Ignored (kept for interface compatibility). + attribute_suffix: Ignored (kept for interface compatibility). Returns: - A GeminiSchema with internal dictionary represents the JSON constraint. 
+ A FormatModeSchema configured for JSON output. """ - # Track attribute types for each category - extraction_categories: dict[str, dict[str, set[type]]] = {} - for example in examples_data: - for extraction in example.extractions: - category = extraction.extraction_class - if category not in extraction_categories: - extraction_categories[category] = {} - - if extraction.attributes: - for attr_name, attr_value in extraction.attributes.items(): - if attr_name not in extraction_categories[category]: - extraction_categories[category][attr_name] = set() - extraction_categories[category][attr_name].add(type(attr_value)) - - extraction_properties: dict[str, dict[str, Any]] = {} - - for category, attrs in extraction_categories.items(): - extraction_properties[category] = {"type": "string"} - - attributes_field = f"{category}{attribute_suffix}" - attr_properties = {} - - # If no attributes were found for this category, add a default property. - if not attrs: - attr_properties["_unused"] = {"type": "string"} - else: - for attr_name, attr_types in attrs.items(): - # If we see list type, use array of strings - if list in attr_types: - attr_properties[attr_name] = { - "type": "array", - "items": {"type": "string"}, - } - else: - attr_properties[attr_name] = {"type": "string"} - - extraction_properties[attributes_field] = { - "type": "object", - "properties": attr_properties, - "nullable": True, - } - - extraction_schema = { - "type": "object", - "properties": extraction_properties, - } - - schema_dict = { - "type": "object", - "properties": { - EXTRACTIONS_KEY: {"type": "array", "items": extraction_schema} - }, - "required": [EXTRACTIONS_KEY], - } - - return cls(_schema_dict=schema_dict) + return cls(format_mode="json") + + def to_provider_config(self) -> dict[str, Any]: + """Convert to provider configuration. + + Returns: + Dictionary with format parameter. + """ + return {"format": self._format} + + @property + def supports_strict_mode(self) -> bool: + """Whether the format guarantees valid output without fences. + + Returns: + True for JSON (guaranteed valid syntax), False for YAML/others. + """ + return self._format == "json" + + def sync_with_provider_kwargs(self, kwargs: dict[str, Any]) -> None: + """Update format based on provider kwargs. + + Args: + kwargs: The effective provider kwargs after merging. 
+ """ + if "format" in kwargs: + self._format = kwargs["format"] + + +# TODO(v2.0.0): Remove GeminiSchema re-export +# pylint: disable=wrong-import-position,cyclic-import +from langextract.providers.schemas import gemini as gemini_schema + +GeminiSchema = gemini_schema.GeminiSchema + +__all__ = [ + "BaseSchema", + "FormatModeSchema", + "Constraint", + "ConstraintType", + "GeminiSchema", + "EXTRACTIONS_KEY", +] diff --git a/pyproject.toml b/pyproject.toml index 3831c844..53ccff96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ notebook = [ ] [tool.setuptools] -packages = ["langextract", "langextract.providers"] +packages = ["langextract", "langextract.providers", "langextract.providers.schemas"] include-package-data = false [tool.setuptools.exclude-package-data] diff --git a/tests/extract_precedence_test.py b/tests/extract_precedence_test.py index fc40613c..26c69853 100644 --- a/tests/extract_precedence_test.py +++ b/tests/extract_precedence_test.py @@ -47,6 +47,7 @@ def setUp(self): def test_model_overrides_all_other_parameters( self, mock_create_model, mock_annotator_cls ): + """Test that model parameter overrides all other model-related parameters.""" provided_model = mock.MagicMock() mock_annotator = mock_annotator_cls.return_value mock_annotator.annotate_text.return_value = "ok" @@ -75,10 +76,12 @@ def test_model_overrides_all_other_parameters( def test_config_overrides_model_id_and_language_model_type( self, mock_create_model, mock_annotator_cls ): + """Test that config parameter overrides model_id and language_model_type.""" config = factory.ModelConfig( model_id="config-model", provider_kwargs={"api_key": "config-key"} ) mock_model = mock.MagicMock() + mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_annotator = mock_annotator_cls.return_value mock_annotator.annotate_text.return_value = "ok" @@ -96,10 +99,13 @@ def test_config_overrides_model_id_and_language_model_type( ) mock_model_config.assert_not_called() - mock_create_model.assert_called_once_with(config) + mock_create_model.assert_called_once() + called_config = mock_create_model.call_args[1]["config"] + self.assertEqual(called_config.model_id, "config-model") + self.assertEqual(called_config.provider_kwargs, {"api_key": "config-key"}) + _, kwargs = mock_annotator_cls.call_args self.assertIs(kwargs["language_model"], mock_model) - self.assertEqual(config.provider_kwargs, {"api_key": "config-key"}) self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @@ -107,7 +113,9 @@ def test_config_overrides_model_id_and_language_model_type( def test_model_id_and_base_kwargs_override_language_model_type( self, mock_create_model, mock_annotator_cls ): + """Test that model_id and other kwargs are used when no model or config.""" mock_model = mock.MagicMock() + mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_annotator_cls.return_value.annotate_text.return_value = "ok" mock_config = mock.MagicMock() @@ -136,7 +144,7 @@ def test_model_id_and_base_kwargs_override_language_model_type( self.assertEqual(provider_kwargs["temperature"], 0.9) self.assertEqual(provider_kwargs["model_url"], "http://model") self.assertEqual(provider_kwargs["base_url"], "http://model") - mock_create_model.assert_called_once_with(mock_config) + mock_create_model.assert_called_once() self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @@ -144,7 +152,9 @@ def test_model_id_and_base_kwargs_override_language_model_type( def 
test_language_model_type_only_emits_warning_and_works( self, mock_create_model, mock_annotator_cls ): + """Test that language_model_type emits deprecation warning but still works.""" mock_model = mock.MagicMock() + mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_annotator_cls.return_value.annotate_text.return_value = "ok" mock_config = mock.MagicMock() @@ -164,7 +174,7 @@ def test_language_model_type_only_emits_warning_and_works( mock_model_config.assert_called_once() _, kwargs = mock_model_config.call_args self.assertEqual(kwargs["model_id"], "gemini-2.5-flash") - mock_create_model.assert_called_once_with(mock_config) + mock_create_model.assert_called_once() self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @@ -178,6 +188,7 @@ def test_use_schema_constraints_warns_with_config( ) mock_model = mock.MagicMock() + mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_annotator = mock_annotator_cls.return_value mock_annotator.annotate_text.return_value = "ok" @@ -191,12 +202,11 @@ def test_use_schema_constraints_warns_with_config( use_schema_constraints=True, ) - self.assertIn("use_schema_constraints", str(cm.warning)) - self.assertIn("ignored", str(cm.warning)) + self.assertIn("schema constraints", str(cm.warning)) + self.assertIn("applied", str(cm.warning)) mock_create_model.assert_called_once() - called_config = mock_create_model.call_args[0][0] + called_config = mock_create_model.call_args[1]["config"] self.assertEqual(called_config.model_id, "gemini-2.5-flash") - self.assertNotIn("gemini_schema", called_config.provider_kwargs) self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") diff --git a/tests/extract_schema_integration_test.py b/tests/extract_schema_integration_test.py new file mode 100644 index 00000000..b6ffa49f --- /dev/null +++ b/tests/extract_schema_integration_test.py @@ -0,0 +1,216 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration tests for extract function with new schema system.""" + +from unittest import mock +import warnings + +from absl.testing import absltest + +from langextract import data +import langextract as lx + + +class ExtractSchemaIntegrationTest(absltest.TestCase): + """Tests for extract function with schema system integration.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.examples = [ + data.ExampleData( + text="Patient has diabetes", + extractions=[ + data.Extraction( + extraction_class="condition", + extraction_text="diabetes", + attributes={"severity": "moderate"}, + ) + ], + ) + ] + self.test_text = "Patient has hypertension" + + @mock.patch.dict("os.environ", {"GEMINI_API_KEY": "test_key"}) + def test_extract_with_gemini_uses_schema(self): + """Test that extract with Gemini automatically uses schema.""" + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ) as mock_init: + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.infer", + return_value=iter([[mock.Mock(output='{"extractions": []}')]]), + ): + with mock.patch( + "langextract.annotation.Annotator.annotate_text", + return_value=data.AnnotatedDocument( + text=self.test_text, extractions=[] + ), + ): + result = lx.extract( + text_or_documents=self.test_text, + prompt_description="Extract conditions", + examples=self.examples, + model_id="gemini-2.5-flash", + use_schema_constraints=True, + fence_output=None, # Let it compute + ) + + # Should have been called with response_schema + call_kwargs = mock_init.call_args[1] + self.assertIn("response_schema", call_kwargs) + + # Result should be an AnnotatedDocument + self.assertIsInstance(result, data.AnnotatedDocument) + + @mock.patch.dict("os.environ", {"OLLAMA_BASE_URL": "http://localhost:11434"}) + def test_extract_with_ollama_uses_json_mode(self): + """Test that extract with Ollama uses JSON mode.""" + with mock.patch( + "langextract.providers.ollama.OllamaLanguageModel.__init__", + return_value=None, + ) as mock_init: + with mock.patch( + "langextract.providers.ollama.OllamaLanguageModel.infer", + return_value=iter([[mock.Mock(output='{"extractions": []}')]]), + ): + with mock.patch( + "langextract.annotation.Annotator.annotate_text", + return_value=data.AnnotatedDocument( + text=self.test_text, extractions=[] + ), + ): + result = lx.extract( + text_or_documents=self.test_text, + prompt_description="Extract conditions", + examples=self.examples, + model_id="gemma2:2b", + use_schema_constraints=True, + fence_output=None, # Let it compute + ) + + # Should have been called with format="json" + call_kwargs = mock_init.call_args[1] + self.assertIn("format", call_kwargs) + self.assertEqual(call_kwargs["format"], "json") + + # Result should be an AnnotatedDocument + self.assertIsInstance(result, data.AnnotatedDocument) + + def test_extract_explicit_fence_respected(self): + """Test that explicit fence_output is respected in extract.""" + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ): + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.infer", + return_value=iter([[mock.Mock(output='{"extractions": []}')]]), + ): + with mock.patch( + "langextract.annotation.Annotator.__init__", return_value=None + ) as mock_annotator_init: + with mock.patch( + "langextract.annotation.Annotator.annotate_text", + return_value=data.AnnotatedDocument( + text=self.test_text, extractions=[] + ), + ): + _ = lx.extract( + 
text_or_documents=self.test_text, + prompt_description="Extract conditions", + examples=self.examples, + model_id="gemini-2.5-flash", + api_key="test_key", + use_schema_constraints=True, + fence_output=True, # Explicitly set + ) + + # Annotator should be created with fence_output=True + call_kwargs = mock_annotator_init.call_args[1] + self.assertTrue(call_kwargs["fence_output"]) + + def test_extract_gemini_schema_deprecation_warning(self): + """Test that passing gemini_schema triggers deprecation warning.""" + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ): + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.infer", + return_value=iter([[mock.Mock(output='{"extractions": []}')]]), + ): + with mock.patch( + "langextract.annotation.Annotator.annotate_text", + return_value=data.AnnotatedDocument( + text=self.test_text, extractions=[] + ), + ): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + _ = lx.extract( + text_or_documents=self.test_text, + prompt_description="Extract conditions", + examples=self.examples, + model_id="gemini-2.5-flash", + api_key="test_key", + language_model_params={ + "gemini_schema": "some_schema" + }, # Deprecated + ) + + # Should have triggered deprecation warning + deprecation_warnings = [ + warning + for warning in w + if issubclass(warning.category, DeprecationWarning) + and "gemini_schema" in str(warning.message) + ] + self.assertGreater(len(deprecation_warnings), 0) + + def test_extract_no_schema_when_disabled(self): + """Test that no schema is used when use_schema_constraints=False.""" + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ) as mock_init: + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.infer", + return_value=iter([[mock.Mock(output='{"extractions": []}')]]), + ): + with mock.patch( + "langextract.annotation.Annotator.annotate_text", + return_value=data.AnnotatedDocument( + text=self.test_text, extractions=[] + ), + ): + _ = lx.extract( + text_or_documents=self.test_text, + prompt_description="Extract conditions", + examples=self.examples, + model_id="gemini-2.5-flash", + api_key="test_key", + use_schema_constraints=False, # Disabled + ) + + # Should NOT have response_schema + call_kwargs = mock_init.call_args[1] + self.assertNotIn("response_schema", call_kwargs) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/factory_schema_test.py b/tests/factory_schema_test.py new file mode 100644 index 00000000..64339066 --- /dev/null +++ b/tests/factory_schema_test.py @@ -0,0 +1,258 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for factory schema integration and fence defaulting.""" + +from unittest import mock + +from absl.testing import absltest + +from langextract import data +from langextract import factory +from langextract import inference +from langextract import schema + + +class FactorySchemaIntegrationTest(absltest.TestCase): + """Tests for create_model_with_schema factory function.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.examples = [ + data.ExampleData( + text="Test text", + extractions=[ + data.Extraction( + extraction_class="test_class", + extraction_text="test extraction", + ) + ], + ) + ] + + @mock.patch.dict("os.environ", {"GEMINI_API_KEY": "test_key"}) + def test_gemini_with_schema_returns_false_fence(self): + """Test that Gemini with schema returns fence_output=False.""" + config = factory.ModelConfig( + model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} + ) + + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ) as mock_init: + model = factory._create_model_with_schema( + config=config, + examples=self.examples, + use_schema_constraints=True, + fence_output=None, # Let it compute default + ) + + # Should have called init with response_schema in kwargs + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + self.assertIn("response_schema", call_kwargs) + + # Fence should be False for strict schema + self.assertFalse(model.requires_fence_output) + + @mock.patch.dict("os.environ", {"OLLAMA_BASE_URL": "http://localhost:11434"}) + def test_ollama_with_schema_returns_false_fence(self): + """Test that Ollama with JSON mode returns fence_output=False.""" + config = factory.ModelConfig(model_id="gemma2:2b") + + with mock.patch( + "langextract.providers.ollama.OllamaLanguageModel.__init__", + return_value=None, + ) as mock_init: + model = factory._create_model_with_schema( + config=config, + examples=self.examples, + use_schema_constraints=True, + fence_output=None, # Let it compute default + ) + + # Should have called init with format in kwargs + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + self.assertIn("format", call_kwargs) + self.assertEqual(call_kwargs["format"], "json") + + # Fence should be False since Ollama JSON mode outputs valid JSON + self.assertFalse(model.requires_fence_output) + + def test_explicit_fence_output_respected(self): + """Test that explicit fence_output is not overridden.""" + config = factory.ModelConfig( + model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} + ) + + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ): + # Explicitly set fence to True (opposite of default for Gemini) + model = factory._create_model_with_schema( + config=config, + examples=self.examples, + use_schema_constraints=True, + fence_output=True, # Explicit value + ) + + # Should respect explicit value + self.assertTrue(model.requires_fence_output) + + def test_no_schema_defaults_to_true_fence(self): + """Test that models without schema support default to fence_output=True.""" + + class NoSchemaModel(inference.BaseLanguageModel): + + def infer(self, batch_prompts, **kwargs): + yield [] + + config = factory.ModelConfig(model_id="test-model") + + with mock.patch( + "langextract.providers.registry.resolve", return_value=NoSchemaModel + ): + with mock.patch.object(NoSchemaModel, "__init__", return_value=None): + model = factory._create_model_with_schema( + config=config, + 
examples=self.examples, + use_schema_constraints=True, + fence_output=None, + ) + + # Should default to True for backward compatibility + self.assertTrue(model.requires_fence_output) + + def test_schema_disabled_returns_true_fence(self): + """Test that disabling schema constraints returns fence_output=True.""" + config = factory.ModelConfig( + model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} + ) + + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ) as mock_init: + model = factory._create_model_with_schema( + config=config, + examples=self.examples, + use_schema_constraints=False, # Disabled + fence_output=None, + ) + + # Should not have response_schema in kwargs + call_kwargs = mock_init.call_args[1] + self.assertNotIn("response_schema", call_kwargs) + + # Should default to True when no schema + self.assertTrue(model.requires_fence_output) + + def test_caller_overrides_schema_config(self): + """Test that caller's provider_kwargs override schema configuration.""" + # Use Ollama which normally sets format=json + config = factory.ModelConfig( + model_id="gemma2:2b", + provider_kwargs={"format": "yaml"}, # Caller wants YAML + ) + + with mock.patch( + "langextract.providers.ollama.OllamaLanguageModel.__init__", + return_value=None, + ) as mock_init: + _ = factory._create_model_with_schema( + config=config, + examples=self.examples, + use_schema_constraints=True, + fence_output=None, + ) + + # Should have called init with caller's YAML override + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + self.assertIn("format", call_kwargs) + self.assertEqual(call_kwargs["format"], "yaml") # Caller wins! + + def test_no_examples_no_schema(self): + """Test that no examples means no schema is created.""" + config = factory.ModelConfig( + model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} + ) + + with mock.patch( + "langextract.providers.gemini.GeminiLanguageModel.__init__", + return_value=None, + ) as mock_init: + model = factory._create_model_with_schema( + config=config, + examples=None, + use_schema_constraints=True, + fence_output=None, + ) + + # Should not have response_schema in kwargs + call_kwargs = mock_init.call_args[1] + self.assertNotIn("response_schema", call_kwargs) + + # Should default to True when no schema + self.assertTrue(model.requires_fence_output) + + +class SchemaApplicationTest(absltest.TestCase): + """Tests for apply_schema being called on models.""" + + def test_apply_schema_called_when_supported(self): + """Test that apply_schema is called on models that support it.""" + examples = [ + data.ExampleData( + text="Test", + extractions=[ + data.Extraction(extraction_class="test", extraction_text="test") + ], + ) + ] + + class SchemaAwareModel(inference.BaseLanguageModel): + + @classmethod + def get_schema_class(cls): + return schema.GeminiSchema + + def infer(self, batch_prompts, **kwargs): + yield [] + + config = factory.ModelConfig(model_id="test-model") + + with mock.patch( + "langextract.providers.registry.resolve", return_value=SchemaAwareModel + ): + with mock.patch.object(SchemaAwareModel, "__init__", return_value=None): + with mock.patch.object(SchemaAwareModel, "apply_schema") as mock_apply: + _ = factory._create_model_with_schema( + config=config, + examples=examples, + use_schema_constraints=True, + ) + + # apply_schema should have been called with the schema instance + mock_apply.assert_called_once() + schema_arg = mock_apply.call_args[0][0] + 
self.assertIsInstance(schema_arg, schema.GeminiSchema) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/init_test.py b/tests/init_test.py index b40f0e1d..b160b67b 100644 --- a/tests/init_test.py +++ b/tests/init_test.py @@ -37,7 +37,6 @@ def test_lang_extract_as_lx_extract( input_text = "Patient takes Aspirin 100mg every morning." - # Create a mock model instance mock_model = mock.MagicMock() mock_model.infer.return_value = [[ inference.ScoredOutput( @@ -65,10 +64,10 @@ def test_lang_extract_as_lx_extract( ) ]] - # Make factory return our mock model + mock_model.requires_fence_output = True mock_create_model.return_value = mock_model - mock_gemini_schema.return_value = None # No live endpoint to process schema + mock_gemini_schema.return_value = None expected_result = data.AnnotatedDocument( document_id=None, @@ -137,10 +136,57 @@ def test_lang_extract_as_lx_extract( mock_create_model.assert_called_once() mock_model.infer.assert_called_once_with( batch_prompts=[prompt_generator.render(input_text)], + max_workers=10, # Default value from extract() ) self.assertDataclassEqual(expected_result, actual_result) + @mock.patch.object(schema.GeminiSchema, "from_examples", autospec=True) + @mock.patch("langextract.factory.create_model") + def test_extract_custom_params_reach_inference( + self, mock_create_model, mock_gemini_schema + ): + """Sanity check that custom parameters reach the inference layer.""" + input_text = "Test text" + + mock_model = mock.MagicMock() + mock_model.infer.return_value = [[ + inference.ScoredOutput( + output='```json\n{"extractions": []}\n```', + score=0.9, + ) + ]] + + mock_model.requires_fence_output = True + mock_create_model.return_value = mock_model + mock_gemini_schema.return_value = None + + mock_examples = [ + lx.data.ExampleData( + text="Example", + extractions=[ + lx.data.Extraction( + extraction_class="test", + extraction_text="example", + ), + ], + ) + ] + + lx.extract( + text_or_documents=input_text, + prompt_description="Test extraction", + examples=mock_examples, + api_key="test_key", + max_workers=5, + fence_output=True, + use_schema_constraints=False, + ) + + mock_model.infer.assert_called_once() + _, kwargs = mock_model.infer.call_args + self.assertEqual(kwargs.get("max_workers"), 5) + if __name__ == "__main__": absltest.main() diff --git a/tests/progress_test.py b/tests/progress_test.py index bfe266d1..6d6bb83c 100644 --- a/tests/progress_test.py +++ b/tests/progress_test.py @@ -56,12 +56,10 @@ def test_save_load_progress_bars(self): def test_model_info_extraction(self): """Test extracting model info from objects.""" - # Test with model_id mock_model = mock.MagicMock() mock_model.model_id = "gemini-1.5-pro" self.assertEqual(progress.get_model_info(mock_model), "gemini-1.5-pro") - # Test with no attributes mock_model = mock.MagicMock() del mock_model.model_id del mock_model.model_url diff --git a/tests/provider_plugin_test.py b/tests/provider_plugin_test.py index db6bed9c..483d8af7 100644 --- a/tests/provider_plugin_test.py +++ b/tests/provider_plugin_test.py @@ -261,6 +261,105 @@ def infer(self, batch_prompts, **kwargs): cls_by_partial = lx.providers.registry.resolve_provider("resolveme") self.assertEqual(cls_by_partial.__name__, "ResolveMePlease") + def test_plugin_with_custom_schema(self): + """Test that a plugin can provide its own schema implementation.""" + + class TestPluginSchema(lx.schema.BaseSchema): + """Test schema implementation.""" + + def __init__(self, config): + self._config = config + + @classmethod + def 
from_examples(cls, examples_data, attribute_suffix="_attributes"): + return cls({"generated": True, "count": len(examples_data)}) + + def to_provider_config(self): + return {"custom_schema": self._config} + + @property + def supports_strict_mode(self): + return True + + def _ep_load(): + @lx.providers.registry.register(r"^custom-schema-test") + class SchemaTestProvider(lx.inference.BaseLanguageModel): + + def __init__(self, model_id=None, **kwargs): + super().__init__() + self.model_id = model_id + self.schema_config = kwargs.get("custom_schema") + + @classmethod + def get_schema_class(cls): + return TestPluginSchema + + def infer(self, batch_prompts, **kwargs): + output = ( + f"Schema={self.schema_config}" + if self.schema_config + else "No schema" + ) + return [[lx.inference.ScoredOutput(score=1.0, output=output)]] + + return SchemaTestProvider + + ep = types.SimpleNamespace( + name="schema_test", + group="langextract.providers", + value="test:SchemaTestProvider", + load=_ep_load, + ) + + with mock.patch.object( + metadata, + "entry_points", + side_effect=lambda **kw: [ep] + if kw.get("group") == "langextract.providers" + else [], + ): + lx.providers.load_plugins_once() + + provider_cls = lx.providers.registry.resolve("custom-schema-test-v1") + self.assertEqual( + provider_cls.get_schema_class().__name__, + "TestPluginSchema", + "Plugin should provide custom schema class", + ) + + examples = [ + lx.data.ExampleData( + text="Test", + extractions=[ + lx.data.Extraction( + extraction_class="test", + extraction_text="test text", + ) + ], + ) + ] + + config = lx.factory.ModelConfig(model_id="custom-schema-test-v1") + model = lx.factory._create_model_with_schema( + config=config, + examples=examples, + use_schema_constraints=True, + fence_output=None, + ) + + self.assertIsNotNone( + model.schema_config, + "Model should have schema config applied", + ) + self.assertTrue( + model.schema_config["generated"], + "Schema should be generated from examples", + ) + self.assertFalse( + model.requires_fence_output, + "Schema supports strict mode, no fences needed", + ) + class PluginE2ETest(absltest.TestCase): """End-to-end test with actual pip installation. @@ -269,6 +368,108 @@ class PluginE2ETest(absltest.TestCase): via tox -e plugin-e2e or in CI when provider files change. 
""" + def test_plugin_with_schema_e2e(self): + """Test that a plugin with custom schema works end-to-end with extract().""" + + class TestPluginSchema(lx.schema.BaseSchema): + """Test schema implementation.""" + + def __init__(self, config): + self._config = config + + @classmethod + def from_examples(cls, examples_data, attribute_suffix="_attributes"): + return cls({"generated": True, "count": len(examples_data)}) + + def to_provider_config(self): + return {"custom_schema": self._config} + + @property + def supports_strict_mode(self): + return True + + def _ep_load(): + @lx.providers.registry.register(r"^e2e-schema-test") + class SchemaE2EProvider(lx.inference.BaseLanguageModel): + + def __init__(self, model_id=None, **kwargs): + super().__init__() + self.model_id = model_id + self.schema_config = kwargs.get("custom_schema") + + @classmethod + def get_schema_class(cls): + return TestPluginSchema + + def infer(self, batch_prompts, **kwargs): + # Return a mock extraction that includes schema info + if self.schema_config: + output = ( + '{"extractions": [{"entity": "test", ' + '"entity_attributes": {"schema": "applied"}}]}' + ) + else: + output = '{"extractions": []}' + return [[lx.inference.ScoredOutput(score=1.0, output=output)]] + + return SchemaE2EProvider + + ep = types.SimpleNamespace( + name="schema_e2e", + group="langextract.providers", + value="test:SchemaE2EProvider", + load=_ep_load, + ) + + # Clear and set up registry + lx.providers.registry.clear() + lx.providers._PLUGINS_LOADED = False + self.addCleanup(lx.providers.registry.clear) + self.addCleanup(setattr, lx.providers, "_PLUGINS_LOADED", False) + + with mock.patch.object( + metadata, + "entry_points", + side_effect=lambda **kw: [ep] + if kw.get("group") == "langextract.providers" + else [], + ): + lx.providers.load_plugins_once() + + # Test with extract() using schema constraints + examples = [ + lx.data.ExampleData( + text="Find entities", + extractions=[ + lx.data.Extraction( + extraction_class="entity", + extraction_text="example", + attributes={"type": "test"}, + ) + ], + ) + ] + + result = lx.extract( + text_or_documents="Test text for extraction", + prompt_description="Extract entities", + examples=examples, + model_id="e2e-schema-test-v1", + use_schema_constraints=True, + fence_output=False, # Schema supports strict mode + ) + + # Verify we got results + self.assertIsInstance(result, lx.data.AnnotatedDocument) + self.assertIsNotNone(result.extractions) + self.assertGreater(len(result.extractions), 0) + + # Verify the schema was applied by checking the extraction + extraction = result.extractions[0] + self.assertEqual(extraction.extraction_class, "entity") + self.assertIn("schema", extraction.attributes) + self.assertEqual(extraction.attributes["schema"], "applied") + @pytest.mark.requires_pip @pytest.mark.integration def test_pip_install_discovery_and_cleanup(self): @@ -294,16 +495,39 @@ def test_pip_install_discovery_and_cleanup(self): USED_BY_EXTRACT = False + class TestPipSchema(lx.schema.BaseSchema): + '''Test schema for pip provider.''' + + def __init__(self, config): + self._config = config + + @classmethod + def from_examples(cls, examples_data, attribute_suffix="_attributes"): + return cls({"pip_schema": True, "examples": len(examples_data)}) + + def to_provider_config(self): + return {"schema_config": self._config} + + @property + def supports_strict_mode(self): + return True + @lx.providers.registry.register(r'^test-pip-model', priority=50) class TestPipProvider(lx.inference.BaseLanguageModel): def 
__init__(self, model_id, **kwargs): super().__init__() self.model_id = model_id + self.schema_config = kwargs.get("schema_config", {}) + + @classmethod + def get_schema_class(cls): + return TestPipSchema def infer(self, batch_prompts, **kwargs): global USED_BY_EXTRACT USED_BY_EXTRACT = True - return [[lx.inference.ScoredOutput(score=1.0, output="pip test response")]] + schema_info = "with_schema" if self.schema_config else "no_schema" + return [[lx.inference.ScoredOutput(score=1.0, output=f"pip test response: {schema_info}")]] """)) (pkg_dir / "pyproject.toml").write_text(textwrap.dedent(f""" @@ -357,20 +581,46 @@ def infer(self, batch_prompts, **kwargs): lx.providers.load_plugins_once() - # Test via factory.create_model + # Test 1: Basic usage without schema cfg = lx.factory.ModelConfig(model_id="test-pip-model-123") model = lx.factory.create_model(cfg) result = model.infer(["test prompt"]) - assert result[0][0].output == "pip test response", f"Got: {{result[0][0].output}}" - - # Verify the plugin is resolvable via the registry - resolved = lx.providers.registry.resolve("test-pip-model-xyz") - assert resolved.__name__ == "TestPipProvider", "Plugin should be resolvable" + assert "no_schema" in result[0][0].output, f"Got: {{result[0][0].output}}" + + # Test 2: With schema constraints + examples = [ + lx.data.ExampleData( + text="test", + extractions=[ + lx.data.Extraction( + extraction_class="test", + extraction_text="test", + ) + ], + ) + ] + + cfg2 = lx.factory.ModelConfig(model_id="test-pip-model-456") + model2 = lx.factory._create_model_with_schema( + config=cfg2, + examples=examples, + use_schema_constraints=True, + fence_output=None, + ) + result2 = model2.infer(["test prompt"]) + assert "with_schema" in result2[0][0].output, f"Got: {{result2[0][0].output}}" + assert model2.requires_fence_output == False, "Schema supports strict mode, should not need fences" + + # Test 3: Verify schema class is available + provider_cls = lx.providers.registry.resolve("test-pip-model-xyz") + assert provider_cls.__name__ == "TestPipProvider", "Plugin should be resolvable" + schema_cls = provider_cls.get_schema_class() + assert schema_cls.__name__ == "TestPipSchema", f"Schema class should be TestPipSchema, got {{schema_cls.__name__}}" from {pkg_name}.provider import USED_BY_EXTRACT assert USED_BY_EXTRACT, "Provider infer() was not called" - print("SUCCESS: Plugin test passed") + print("SUCCESS: Plugin test with schema passed") """)) result = subprocess.run( diff --git a/tests/provider_schema_test.py b/tests/provider_schema_test.py new file mode 100644 index 00000000..27423338 --- /dev/null +++ b/tests/provider_schema_test.py @@ -0,0 +1,479 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for provider schema discovery and implementations.""" + +from unittest import mock + +from absl.testing import absltest + +from langextract import data +from langextract import exceptions +from langextract import factory +from langextract import schema +from langextract.providers import gemini as gemini_provider +from langextract.providers import ollama +from langextract.providers import openai + + +class ProviderSchemaDiscoveryTest(absltest.TestCase): + """Tests for provider schema discovery via get_schema_class().""" + + def test_gemini_returns_gemini_schema(self): + """Test that GeminiLanguageModel returns GeminiSchema.""" + schema_class = gemini_provider.GeminiLanguageModel.get_schema_class() + self.assertEqual( + schema_class, + schema.GeminiSchema, + msg="GeminiLanguageModel should return GeminiSchema class", + ) + + def test_ollama_returns_format_mode_schema(self): + """Test that OllamaLanguageModel returns FormatModeSchema.""" + schema_class = ollama.OllamaLanguageModel.get_schema_class() + self.assertEqual( + schema_class, + schema.FormatModeSchema, + msg="OllamaLanguageModel should return FormatModeSchema class", + ) + + def test_openai_returns_none(self): + """Test that OpenAILanguageModel returns None (no schema support yet).""" + # OpenAI imports dependencies in __init__, not at module level + schema_class = openai.OpenAILanguageModel.get_schema_class() + self.assertIsNone( + schema_class, + msg="OpenAILanguageModel should return None (no schema support)", + ) + + +class FormatModeSchemaTest(absltest.TestCase): + """Tests for FormatModeSchema implementation.""" + + def test_from_examples_ignores_examples(self): + """Test that FormatModeSchema ignores examples and returns JSON mode.""" + examples_data = [ + data.ExampleData( + text="Test text", + extractions=[ + data.Extraction( + extraction_class="test_class", + extraction_text="test extraction", + attributes={"key": "value"}, + ) + ], + ) + ] + + test_schema = schema.FormatModeSchema.from_examples(examples_data) + self.assertEqual( + test_schema._format, + "json", + msg="FormatModeSchema should default to JSON format", + ) + + def test_to_provider_config_returns_format(self): + """Test that to_provider_config returns format parameter.""" + examples_data = [] + test_schema = schema.FormatModeSchema.from_examples(examples_data) + + provider_config = test_schema.to_provider_config() + + self.assertEqual( + provider_config, + {"format": "json"}, + msg="Provider config should contain format: json", + ) + + def test_supports_strict_mode_returns_true(self): + """Test that FormatModeSchema supports strict mode (valid JSON output).""" + examples_data = [] + test_schema = schema.FormatModeSchema.from_examples(examples_data) + + self.assertTrue( + test_schema.supports_strict_mode, + msg="FormatModeSchema should support strict mode", + ) + + def test_different_examples_same_output(self): + """Test that different examples produce the same schema for Ollama.""" + examples1 = [ + data.ExampleData( + text="Text 1", + extractions=[ + data.Extraction( + extraction_class="class1", extraction_text="text1" + ) + ], + ) + ] + + examples2 = [ + data.ExampleData( + text="Text 2", + extractions=[ + data.Extraction( + extraction_class="class2", + extraction_text="text2", + attributes={"attr": "value"}, + ) + ], + ) + ] + + schema1 = schema.FormatModeSchema.from_examples(examples1) + schema2 = schema.FormatModeSchema.from_examples(examples2) + + # Examples are ignored by FormatModeSchema + self.assertEqual( + schema1.to_provider_config(), + 
schema2.to_provider_config(), + msg="Different examples should produce same config for Ollama", + ) + + +class OllamaYAMLOverrideTest(absltest.TestCase): + """Tests for Ollama YAML format override behavior.""" + + def test_ollama_yaml_format_in_request_payload(self): + """Test that YAML format override appears in Ollama request payload.""" + with mock.patch("requests.post", autospec=True) as mock_post: + mock_response = mock.Mock(spec=["status_code", "json"]) + mock_response.status_code = 200 + mock_response.json.return_value = {"response": '{"extractions": []}'} + mock_post.return_value = mock_response + + model = ollama.OllamaLanguageModel(model_id="gemma2:2b", format="yaml") + + list(model.infer(["Test prompt"])) + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + self.assertIn( + "json", call_kwargs, msg="Request should use json parameter" + ) + payload = call_kwargs["json"] + self.assertIn("format", payload, msg="Payload should contain format key") + self.assertEqual(payload["format"], "yaml", msg="Format should be yaml") + + def test_yaml_override_sets_fence_output_true(self): + """Test that overriding to YAML format sets fence_output to True.""" + + examples_data = [ + data.ExampleData( + text="Test text", + extractions=[ + data.Extraction( + extraction_class="test_class", + extraction_text="test extraction", + ) + ], + ) + ] + + with mock.patch("requests.post", autospec=True) as mock_post: + mock_response = mock.Mock(spec=["status_code", "json"]) + mock_response.status_code = 200 + mock_response.json.return_value = {"response": '{"extractions": []}'} + mock_post.return_value = mock_response + + with mock.patch("langextract.providers.registry.resolve") as mock_resolve: + mock_resolve.return_value = ollama.OllamaLanguageModel + + config = factory.ModelConfig( + model_id="gemma2:2b", + provider_kwargs={"format": "yaml"}, + ) + + model = factory.create_model( + config=config, + examples=examples_data, + use_schema_constraints=True, + fence_output=None, # Let it be computed + ) + + self.assertTrue( + model.requires_fence_output, msg="YAML format should require fences" + ) + + def test_json_format_keeps_fence_output_false(self): + """Test that JSON format keeps fence_output False.""" + + examples_data = [ + data.ExampleData( + text="Test text", + extractions=[ + data.Extraction( + extraction_class="test_class", + extraction_text="test extraction", + ) + ], + ) + ] + + with mock.patch("requests.post", autospec=True) as mock_post: + mock_response = mock.Mock(spec=["status_code", "json"]) + mock_response.status_code = 200 + mock_response.json.return_value = {"response": '{"extractions": []}'} + mock_post.return_value = mock_response + + with mock.patch("langextract.providers.registry.resolve") as mock_resolve: + mock_resolve.return_value = ollama.OllamaLanguageModel + + config = factory.ModelConfig( + model_id="gemma2:2b", + provider_kwargs={"format": "json"}, + ) + + model = factory.create_model( + config=config, + examples=examples_data, + use_schema_constraints=True, + fence_output=None, # Let it be computed + ) + + self.assertFalse( + model.requires_fence_output, + msg="JSON format should not require fences", + ) + + +class GeminiSchemaProviderIntegrationTest(absltest.TestCase): + """Tests for GeminiSchema provider integration.""" + + def test_gemini_schema_to_provider_config(self): + """Test that GeminiSchema.to_provider_config includes response_schema.""" + examples_data = [ + data.ExampleData( + text="Patient has diabetes", + extractions=[ + data.Extraction( + 
extraction_class="condition", + extraction_text="diabetes", + attributes={"severity": "moderate"}, + ) + ], + ) + ] + + gemini_schema = schema.GeminiSchema.from_examples(examples_data) + provider_config = gemini_schema.to_provider_config() + + self.assertIn( + "response_schema", + provider_config, + msg="GeminiSchema config should contain response_schema", + ) + self.assertIsInstance( + provider_config["response_schema"], + dict, + msg="response_schema should be a dictionary", + ) + self.assertIn( + "properties", + provider_config["response_schema"], + msg="response_schema should contain properties field", + ) + + self.assertIn( + "response_mime_type", + provider_config, + msg="GeminiSchema config should contain response_mime_type", + ) + self.assertEqual( + provider_config["response_mime_type"], + "application/json", + msg="response_mime_type should be application/json", + ) + + def test_gemini_supports_strict_mode(self): + """Test that GeminiSchema supports strict mode.""" + examples_data = [] + gemini_schema = schema.GeminiSchema.from_examples(examples_data) + self.assertTrue( + gemini_schema.supports_strict_mode, + msg="GeminiSchema should support strict mode", + ) + + def test_gemini_rejects_yaml_with_schema(self): + """Test that Gemini raises error when YAML format is used with schema.""" + + examples_data = [ + data.ExampleData( + text="Test", + extractions=[ + data.Extraction( + extraction_class="test", + extraction_text="test text", + ) + ], + ) + ] + test_schema = schema.GeminiSchema.from_examples(examples_data) + + with mock.patch("google.genai.Client", autospec=True): + model = gemini_provider.GeminiLanguageModel( + model_id="gemini-2.5-flash", + api_key="test_key", + format_type=data.FormatType.YAML, + ) + model.apply_schema(test_schema) + + prompt = "Test prompt" + config = {"temperature": 0.5} + with self.assertRaises(exceptions.InferenceRuntimeError) as cm: + _ = model._process_single_prompt(prompt, config) + + self.assertIn( + "only supports JSON format", + str(cm.exception), + msg="Error should mention JSON-only constraint", + ) + + def test_gemini_forwards_schema_to_genai_client(self): + """Test that GeminiLanguageModel forwards schema config to genai client.""" + + examples_data = [ + data.ExampleData( + text="Test", + extractions=[ + data.Extraction( + extraction_class="test", + extraction_text="test text", + ) + ], + ) + ] + test_schema = schema.GeminiSchema.from_examples(examples_data) + + with mock.patch("google.genai.Client", autospec=True) as mock_client: + mock_model_instance = mock.Mock(spec=["return_value"]) + mock_client.return_value.models.generate_content = mock_model_instance + mock_model_instance.return_value.text = '{"extractions": []}' + + model = gemini_provider.GeminiLanguageModel( + model_id="gemini-2.5-flash", + api_key="test_key", + response_schema=test_schema.schema_dict, + response_mime_type="application/json", + ) + + prompt = "Test prompt" + config = {"temperature": 0.5} + _ = model._process_single_prompt(prompt, config) + + mock_model_instance.assert_called_once() + call_kwargs = mock_model_instance.call_args[1] + self.assertIn( + "config", + call_kwargs, + msg="genai.generate_content should receive config parameter", + ) + self.assertIn( + "response_schema", + call_kwargs["config"], + msg="Config should contain response_schema from GeminiSchema", + ) + self.assertIn( + "response_mime_type", + call_kwargs["config"], + msg="Config should contain response_mime_type", + ) + self.assertEqual( + call_kwargs["config"]["response_mime_type"], + 
"application/json", + msg="response_mime_type should be application/json", + ) + + def test_gemini_doesnt_forward_non_api_kwargs(self): + """Test that GeminiLanguageModel doesn't forward non-API kwargs to genai.""" + + with mock.patch("google.genai.Client", autospec=True) as mock_client: + mock_model_instance = mock.Mock(spec=["return_value"]) + mock_client.return_value.models.generate_content = mock_model_instance + mock_model_instance.return_value.text = '{"extractions": []}' + + model = gemini_provider.GeminiLanguageModel( + model_id="gemini-2.5-flash", + api_key="test_key", + max_workers=5, + response_schema={"test": "schema"}, # API parameter + ) + + prompt = "Test prompt" + config = {"temperature": 0.5} + _ = model._process_single_prompt(prompt, config) + + mock_model_instance.assert_called_once() + call_kwargs = mock_model_instance.call_args[1] + + self.assertNotIn( + "max_workers", + call_kwargs["config"], + msg="max_workers should not be forwarded to genai API config", + ) + + self.assertIn( + "response_schema", + call_kwargs["config"], + msg="response_schema should be forwarded to genai API config", + ) + + +class SchemaShimTest(absltest.TestCase): + """Tests for backward compatibility shims in schema module.""" + + def test_extractions_key_import(self): + """Test that EXTRACTIONS_KEY can be imported from schema module.""" + from langextract import schema as s # pylint: disable=reimported,import-outside-toplevel + + self.assertEqual( + s.EXTRACTIONS_KEY, + "extractions", + msg="EXTRACTIONS_KEY should be 'extractions'", + ) + + def test_constraint_types_import(self): + """Test that Constraint and ConstraintType can be imported.""" + from langextract import schema as s # pylint: disable=reimported,import-outside-toplevel + + constraint = s.Constraint() + self.assertEqual( + constraint.constraint_type, + s.ConstraintType.NONE, + msg="Default Constraint should have type NONE", + ) + + self.assertEqual( + s.ConstraintType.NONE.value, + "none", + msg="ConstraintType.NONE should have value 'none'", + ) + + def test_provider_schema_imports(self): + """Test that provider schemas can be imported from schema module.""" + from langextract import schema as s # pylint: disable=reimported,import-outside-toplevel + + # Backward compatibility: re-exported from providers.schemas.gemini + self.assertTrue( + hasattr(s, "GeminiSchema"), + msg=( + "GeminiSchema should be importable from schema module for backward" + " compatibility" + ), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/registry_test.py b/tests/registry_test.py index 147a264c..b42e79b9 100644 --- a/tests/registry_test.py +++ b/tests/registry_test.py @@ -134,12 +134,10 @@ def test_list_entries(self): entries = registry.list_entries() self.assertEqual(len(entries), 2) - # Check first entry patterns1, priority1 = entries[0] self.assertEqual(patterns1, ["^test1"]) self.assertEqual(priority1, 5) - # Check second entry patterns2, priority2 = entries[1] self.assertEqual(set(patterns2), {"^test2", "^test3"}) self.assertEqual(priority2, 10) diff --git a/tests/schema_test.py b/tests/schema_test.py index d4b067b5..e70b72cd 100644 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -20,9 +20,66 @@ from absl.testing import parameterized from langextract import data +from langextract import inference from langextract import schema +class BaseSchemaTest(absltest.TestCase): + """Tests for BaseSchema abstract class.""" + + def test_abstract_methods_required(self): + """Test that BaseSchema cannot be instantiated 
directly.""" + with self.assertRaises(TypeError): + schema.BaseSchema() # pylint: disable=abstract-class-instantiated + + def test_subclass_must_implement_all_methods(self): + """Test that subclasses must implement all abstract methods.""" + + class IncompleteSchema(schema.BaseSchema): + + @classmethod + def from_examples(cls, examples_data, attribute_suffix="_attributes"): + return cls() + + # Missing to_provider_config and supports_strict_mode + + with self.assertRaises(TypeError): + IncompleteSchema() # pylint: disable=abstract-class-instantiated + + +class BaseLanguageModelSchemaTest(absltest.TestCase): + """Tests for BaseLanguageModel schema methods.""" + + def test_get_schema_class_returns_none_by_default(self): + """Test that get_schema_class returns None by default.""" + + class TestModel(inference.BaseLanguageModel): + + def infer(self, batch_prompts, **kwargs): + yield [] + + self.assertIsNone(TestModel.get_schema_class()) + + def test_apply_schema_stores_instance(self): + """Test that apply_schema stores the schema instance.""" + + class TestModel(inference.BaseLanguageModel): + + def infer(self, batch_prompts, **kwargs): + yield [] + + model = TestModel() + + mock_schema = mock.Mock(spec=schema.BaseSchema) + + model.apply_schema(mock_schema) + + self.assertEqual(model._schema, mock_schema) + + model.apply_schema(None) + self.assertIsNone(model._schema) + + class GeminiSchemaTest(parameterized.TestCase): @parameterized.named_parameters( @@ -179,6 +236,46 @@ def test_from_examples_constructs_expected_schema( actual_schema = gemini_schema.schema_dict self.assertEqual(actual_schema, expected_schema) + def test_to_provider_config_returns_response_schema(self): + """Test that to_provider_config returns the correct provider kwargs.""" + examples_data = [ + data.ExampleData( + text="Test text", + extractions=[ + data.Extraction( + extraction_class="test_class", + extraction_text="test extraction", + ) + ], + ) + ] + + gemini_schema = schema.GeminiSchema.from_examples(examples_data) + provider_config = gemini_schema.to_provider_config() + + # Should contain response_schema key + self.assertIn("response_schema", provider_config) + self.assertEqual( + provider_config["response_schema"], gemini_schema.schema_dict + ) + + def test_supports_strict_mode_returns_true(self): + """Test that GeminiSchema supports strict mode.""" + examples_data = [ + data.ExampleData( + text="Test text", + extractions=[ + data.Extraction( + extraction_class="test_class", + extraction_text="test extraction", + ) + ], + ) + ] + + gemini_schema = schema.GeminiSchema.from_examples(examples_data) + self.assertTrue(gemini_schema.supports_strict_mode) + if __name__ == "__main__": absltest.main() diff --git a/tests/visualization_test.py b/tests/visualization_test.py index 647107f9..0e8c09c7 100644 --- a/tests/visualization_test.py +++ b/tests/visualization_test.py @@ -125,7 +125,6 @@ def test_visualize_basic_document_renders_correctly(self): f'style="background-color:{med_color};">MEDICATION' ) css_html = _VISUALIZATION_CSS - # Build expected components (adapted for animation format) expected_components = [ css_html, "lx-animated-wrapper", From 50ba182cbd0bf17829daee4ee2d6d9bc26e34f08 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Wed, 13 Aug 2025 04:32:23 -0400 Subject: [PATCH 12/17] Support Hugging Face style model IDs for Ollama provider (#131) - Add HF-style patterns (e.g., meta-llama/Llama-3.2-1B-Instruct) to Ollama registry - Add gpt-oss pattern to support issue #116 - Add comprehensive documentation for using 
Ollama with extract() - Include example for direct provider instantiation when ID conflicts exist - Add test coverage for HF-style model ID patterns --- langextract/providers/ollama.py | 71 ++++++++++++++++++++++++++++++++- tests/registry_test.py | 36 +++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/langextract/providers/ollama.py b/langextract/providers/ollama.py index f3fd034f..9e928704 100644 --- a/langextract/providers/ollama.py +++ b/langextract/providers/ollama.py @@ -12,7 +12,61 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Ollama provider for LangExtract.""" +"""Ollama provider for LangExtract. + +This provider enables using local Ollama models with LangExtract's extract() function. +No API key is required since Ollama runs locally on your machine. + +Usage with extract(): + import langextract as lx + from langextract.data import ExampleData, Extraction + + # Create an example for few-shot learning + example = ExampleData( + text="Marie Curie was a pioneering physicist and chemist.", + extractions=[ + Extraction( + extraction_class="person", + extraction_text="Marie Curie", + attributes={"name": "Marie Curie", "field": "physics and chemistry"} + ) + ] + ) + + # Basic usage with Ollama + result = lx.extract( + text_or_documents="Isaac Asimov was a prolific science fiction writer.", + model_id="gemma2:2b", + prompt_description="Extract the person's name and field", + examples=[example], + ) + +Direct provider instantiation (when model ID conflicts with other providers): + from langextract.providers.ollama import OllamaLanguageModel + + # Create Ollama provider directly + model = OllamaLanguageModel( + model_id="gemma2:2b", + model_url="http://localhost:11434", # optional, uses default if not specified + ) + + # Use with extract by passing the model instance + result = lx.extract( + text_or_documents="Your text here", + model=model, # Pass the model instance directly + prompt_description="Extract information", + examples=[example], + ) + +Supported model ID formats: + - Standard Ollama: llama3.2:1b, gemma2:2b, mistral:7b, qwen2.5:7b, etc. + - Hugging Face style: meta-llama/Llama-3.2-1B-Instruct, google/gemma-2b, etc. + +Prerequisites: + 1. Install Ollama: https://ollama.ai + 2. Pull the model: ollama pull gemma2:2b + 3. Ollama server will start automatically when you use extract() +""" # pylint: disable=cyclic-import,duplicate-code from __future__ import annotations @@ -33,7 +87,7 @@ @registry.register( - # Latest open models via Ollama (2024-2025) + # Standard Ollama naming patterns r'^gemma', # gemma2:2b, gemma2:9b, gemma2:27b, etc. r'^llama', # llama3.2:1b, llama3.2:3b, llama3.1:8b, llama3.1:70b, etc. r'^mistral', # mistral:7b, mistral-nemo:12b, mistral-large, etc. @@ -47,6 +101,19 @@ r'^codegemma', # codegemma:2b, codegemma:7b, etc. r'^tinyllama', # tinyllama:1.1b, etc. r'^wizardcoder', # wizardcoder:7b, wizardcoder:13b, wizardcoder:34b, etc. + r'^gpt-oss', # gpt-oss:20b, etc. + # Hugging Face style model IDs (organization/model-name) + r'^meta-llama/[Ll]lama', # meta-llama/Llama-3.2-1B-Instruct, etc. + r'^google/gemma', # google/gemma-2b, google/gemma-7b-it, etc. + r'^mistralai/[Mm]istral', # mistralai/Mistral-7B-v0.1, etc. + r'^mistralai/[Mm]ixtral', # mistralai/Mixtral-8x7B-v0.1, etc. + r'^microsoft/phi', # microsoft/phi-2, microsoft/phi-3-mini, etc. + r'^Qwen/', # Qwen/Qwen2.5-7B, Qwen/Qwen2.5-Coder, etc. + r'^deepseek-ai/', # deepseek-ai/deepseek-coder-v2, etc. 
+ r'^bigcode/starcoder', # bigcode/starcoder2-3b, etc. + r'^codellama/', # codellama/CodeLlama-7b-Python, etc. + r'^TinyLlama/', # TinyLlama/TinyLlama-1.1B-Chat-v1.0, etc. + r'^WizardLM/', # WizardLM/WizardCoder-Python-7B-V1.0, etc. priority=10, ) @dataclasses.dataclass(init=False) diff --git a/tests/registry_test.py b/tests/registry_test.py index b42e79b9..999fbb41 100644 --- a/tests/registry_test.py +++ b/tests/registry_test.py @@ -190,6 +190,42 @@ def test_resolve_provider_not_found(self): registry.resolve_provider("UnknownProvider") self.assertIn("No provider found matching", str(cm.exception)) + def test_hf_style_model_id_patterns(self): + """Test that Hugging Face style model ID patterns work. + + This addresses issue #129 where HF-style model IDs like + 'meta-llama/Llama-3.2-1B-Instruct' weren't being recognized. + """ + + @registry.register( + r"^meta-llama/[Ll]lama", + r"^google/gemma", + r"^mistralai/[Mm]istral", + r"^microsoft/phi", + r"^Qwen/", + r"^TinyLlama/", + priority=100, + ) + class TestHFProvider(inference.BaseLanguageModel): + + def infer(self, batch_prompts, **kwargs): + return [] + + hf_model_ids = [ + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/llama-2-7b", + "google/gemma-2b", + "mistralai/Mistral-7B-v0.1", + "microsoft/phi-3-mini", + "Qwen/Qwen2.5-7B", + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ] + + for model_id in hf_model_ids: + with self.subTest(model_id=model_id): + provider_class = registry.resolve(model_id) + self.assertEqual(provider_class, TestHFProvider) + if __name__ == "__main__": absltest.main() From 031c41cb067e1dda66953c3ba401cd8c8bf311d9 Mon Sep 17 00:00:00 2001 From: goelak Date: Wed, 13 Aug 2025 05:28:13 -0400 Subject: [PATCH 13/17] Add tests for Ollama format parameter handling Test that format='json' is correctly passed to Ollama API --- tests/provider_schema_test.py | 105 ++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/tests/provider_schema_test.py b/tests/provider_schema_test.py index 27423338..1d726b3a 100644 --- a/tests/provider_schema_test.py +++ b/tests/provider_schema_test.py @@ -22,6 +22,7 @@ from langextract import exceptions from langextract import factory from langextract import schema +import langextract as lx from langextract.providers import gemini as gemini_provider from langextract.providers import ollama from langextract.providers import openai @@ -143,6 +144,110 @@ def test_different_examples_same_output(self): ) +class OllamaFormatParameterTest(absltest.TestCase): + """Tests for Ollama format parameter handling.""" + + def test_ollama_json_format_in_request_payload(self): + """Test that JSON format is passed to Ollama API by default.""" + with mock.patch("requests.post", autospec=True) as mock_post: + mock_response = mock.Mock(spec=["status_code", "json"]) + mock_response.status_code = 200 + mock_response.json.return_value = {"response": '{"test": "value"}'} + mock_post.return_value = mock_response + + model = ollama.OllamaLanguageModel( + model_id="test-model", + format_type=data.FormatType.JSON, + ) + + list(model.infer(["Test prompt"])) + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + payload = call_kwargs["json"] + + self.assertEqual(payload["format"], "json", msg="Format should be json") + self.assertEqual( + payload["model"], "test-model", msg="Model ID should match" + ) + self.assertEqual( + payload["prompt"], "Test prompt", msg="Prompt should match" + ) + self.assertFalse(payload["stream"], msg="Stream should be False") + + def 
test_ollama_default_format_is_json(self): + """Test that JSON is the default format when not specified.""" + with mock.patch("requests.post", autospec=True) as mock_post: + mock_response = mock.Mock(spec=["status_code", "json"]) + mock_response.status_code = 200 + mock_response.json.return_value = {"response": '{"test": "value"}'} + mock_post.return_value = mock_response + + model = ollama.OllamaLanguageModel(model_id="test-model") + + list(model.infer(["Test prompt"])) + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + payload = call_kwargs["json"] + + self.assertEqual( + payload["format"], "json", msg="Default format should be json" + ) + + def test_extract_with_ollama_passes_json_format(self): + """Test that lx.extract() correctly passes JSON format to Ollama API.""" + with mock.patch("requests.post", autospec=True) as mock_post: + mock_response = mock.Mock(spec=["status_code", "json"]) + mock_response.status_code = 200 + mock_response.json.return_value = { + "response": ( + '{"extractions": [{"extraction_class": "test", "extraction_text":' + ' "example"}]}' + ) + } + mock_post.return_value = mock_response + + examples = [ + data.ExampleData( + text="Sample text", + extractions=[ + data.Extraction( + extraction_class="test", + extraction_text="sample", + ) + ], + ) + ] + + result = lx.extract( + text_or_documents="Test document", + prompt_description="Extract test information", + examples=examples, + model_id="gemma2:2b", + model_url="http://localhost:11434", + format_type=data.FormatType.JSON, + use_schema_constraints=True, + ) + + mock_post.assert_called() + + last_call = mock_post.call_args_list[-1] + payload = last_call[1]["json"] + + self.assertEqual( + payload["format"], + "json", + msg="Format should be json in extract() call", + ) + self.assertEqual( + payload["model"], "gemma2:2b", msg="Model ID should match" + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, data.AnnotatedDocument) + + class OllamaYAMLOverrideTest(absltest.TestCase): """Tests for Ollama YAML format override behavior.""" From 38528bf023786dcf31ac8608d9353a197fbbd27b Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Wed, 13 Aug 2025 05:40:37 -0400 Subject: [PATCH 14/17] Update Ollama quickstart to use ModelConfig with JSON mode Show both ModelConfig and direct model_id usage patterns --- examples/ollama/quickstart.py | 86 +++++++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 20 deletions(-) diff --git a/examples/ollama/quickstart.py b/examples/ollama/quickstart.py index ed578412..6bc47c25 100644 --- a/examples/ollama/quickstart.py +++ b/examples/ollama/quickstart.py @@ -17,6 +17,7 @@ import argparse import os +import sys import langextract as lx @@ -36,8 +37,10 @@ def run_extraction(model_id="gemma2:2b", temperature=0.3): extractions=[ lx.data.Extraction( extraction_class="author_details", - # extraction_text includes full context with ellipsis for clarity - extraction_text="J.R.R. Tolkien was an English writer...", + extraction_text=( + "J.R.R. Tolkien was an English writer, best known for" + " high-fantasy." + ), attributes={ "name": "J.R.R. 
Tolkien", "genre": "high-fantasy", @@ -47,24 +50,50 @@ def run_extraction(model_id="gemma2:2b", temperature=0.3): ) ] + # Option 1: Use ModelConfig for explicit configuration + # This gives you full control over provider-specific settings + model_config = lx.factory.ModelConfig( + model_id=model_id, + provider_kwargs={ + "model_url": os.getenv("OLLAMA_HOST", "http://localhost:11434"), + "format_type": lx.data.FormatType.JSON, + "temperature": temperature, + }, + ) + result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, - language_model_type=lx.inference.OllamaLanguageModel, - model_id=model_id, - model_url=os.getenv("OLLAMA_HOST", "http://localhost:11434"), - temperature=temperature, - fence_output=False, - use_schema_constraints=False, + config=model_config, + use_schema_constraints=True, ) + # Option 2 (simpler): Just pass model_id directly + # LangExtract's registry automatically identifies Ollama models like "gemma2:2b" + # result = lx.extract( + # text_or_documents=input_text, + # prompt_description=prompt, + # examples=examples, + # model_id=model_id, + # model_url=os.getenv("OLLAMA_HOST", "http://localhost:11434"), + # format_type=lx.data.FormatType.JSON, + # temperature=temperature, + # use_schema_constraints=True, + # ) + return result def main(): """Main function to run the quick-start example.""" - parser = argparse.ArgumentParser(description="Run Ollama extraction example") + parser = argparse.ArgumentParser( + description="Run Ollama extraction example", + epilog=( + "Supported models: gemma2:2b, llama3.2:1b, mistral:7b, qwen2.5:0.5b," + " etc." + ), + ) parser.add_argument( "--model-id", default=os.getenv("MODEL_ID", "gemma2:2b"), @@ -86,23 +115,40 @@ def main(): model_id=args.model_id, temperature=args.temperature ) - for extraction in result.extractions: - print(f"Class: {extraction.extraction_class}") - print(f"Text: {extraction.extraction_text}") - print(f"Attributes: {extraction.attributes}") - - print("\n✅ SUCCESS! Ollama is working with langextract") + if result.extractions: + print(f"\n📝 Found {len(result.extractions)} extraction(s):\n") + for extraction in result.extractions: + print(f"Class: {extraction.extraction_class}") + print(f"Text: {extraction.extraction_text}") + print(f"Attributes: {extraction.attributes}") + print() + else: + print("\n⚠️ No extractions found") + + print("✅ SUCCESS! 
Ollama is working with langextract") + print(f" Model: {args.model_id}") + print(" JSON mode: enabled") + print(" Schema constraints: enabled") return True except ConnectionError as e: - print(f"\nConnectionError: {e}") - print("Make sure Ollama is running: 'ollama serve'") + print(f"\n❌ ConnectionError: {e}") + print("\n💡 Make sure Ollama is running:") + print(" ollama serve") + return False + except ValueError as e: + if "Can't find Ollama" in str(e): + print(f"\n❌ Model not found: {args.model_id}") + print("\n💡 Install the model first:") + print(f" ollama pull {args.model_id}") + else: + print(f"\n❌ ValueError: {e}") return False except Exception as e: - print(f"\nError: {type(e).__name__}: {e}") + print(f"\n❌ Error: {type(e).__name__}: {e}") return False if __name__ == "__main__": - success = main() - exit(0 if success else 1) + SUCCESS = main() + sys.exit(0 if SUCCESS else 1) From e056a85d1cabb907b45ff56ae22a34bbcb79c6d3 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Wed, 13 Aug 2025 05:57:07 -0400 Subject: [PATCH 15/17] Fix test failure by loading providers for registry (#133) --- tests/provider_schema_test.py | 72 ++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/tests/provider_schema_test.py b/tests/provider_schema_test.py index 1d726b3a..c2dffa15 100644 --- a/tests/provider_schema_test.py +++ b/tests/provider_schema_test.py @@ -208,44 +208,48 @@ def test_extract_with_ollama_passes_json_format(self): } mock_post.return_value = mock_response - examples = [ - data.ExampleData( - text="Sample text", - extractions=[ - data.Extraction( - extraction_class="test", - extraction_text="sample", - ) - ], - ) - ] - - result = lx.extract( - text_or_documents="Test document", - prompt_description="Extract test information", - examples=examples, - model_id="gemma2:2b", - model_url="http://localhost:11434", - format_type=data.FormatType.JSON, - use_schema_constraints=True, - ) + # Mock the registry to return OllamaLanguageModel + with mock.patch("langextract.providers.registry.resolve") as mock_resolve: + mock_resolve.return_value = ollama.OllamaLanguageModel - mock_post.assert_called() + examples = [ + data.ExampleData( + text="Sample text", + extractions=[ + data.Extraction( + extraction_class="test", + extraction_text="sample", + ) + ], + ) + ] + + result = lx.extract( + text_or_documents="Test document", + prompt_description="Extract test information", + examples=examples, + model_id="gemma2:2b", + model_url="http://localhost:11434", + format_type=data.FormatType.JSON, + use_schema_constraints=True, + ) - last_call = mock_post.call_args_list[-1] - payload = last_call[1]["json"] + mock_post.assert_called() - self.assertEqual( - payload["format"], - "json", - msg="Format should be json in extract() call", - ) - self.assertEqual( - payload["model"], "gemma2:2b", msg="Model ID should match" - ) + last_call = mock_post.call_args_list[-1] + payload = last_call[1]["json"] + + self.assertEqual( + payload["format"], + "json", + msg="Format should be json in extract() call", + ) + self.assertEqual( + payload["model"], "gemma2:2b", msg="Model ID should match" + ) - self.assertIsNotNone(result) - self.assertIsInstance(result, data.AnnotatedDocument) + self.assertIsNotNone(result) + self.assertIsInstance(result, data.AnnotatedDocument) class OllamaYAMLOverrideTest(absltest.TestCase): From 6a91ce13ed74d1b4ab1c1105864be0efbf902670 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Wed, 13 Aug 2025 06:03:37 -0400 Subject: [PATCH 16/17] Update CITATION.cff 
abstract for accuracy --- CITATION.cff | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CITATION.cff b/CITATION.cff index 2eb3134a..09969746 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -16,7 +16,7 @@ authors: repository-code: "https://github.com/google/langextract" url: "https://github.com/google/langextract" repository: "https://github.com/google/langextract" -abstract: "LangExtract: A library for extracting structured data from language models" +abstract: "LangExtract: LLM-powered structured information extraction from text with source grounding" keywords: - language-models - structured-data-extraction From bdcd41650938e0cf338d6a2764beda575cb042e2 Mon Sep 17 00:00:00 2001 From: Akshay Goel Date: Wed, 13 Aug 2025 06:13:08 -0400 Subject: [PATCH 17/17] Bump version to 1.0.6 (#134) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 53ccff96..1a0db1a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ build-backend = "setuptools.build_meta" [project] name = "langextract" -version = "1.0.5" +version = "1.0.6" description = "LangExtract: A library for extracting structured data from language models" readme = "README.md" requires-python = ">=3.10"