From 4ff367251fde502594b1269831970b592dbe980b Mon Sep 17 00:00:00 2001
From: skamenan7
Date: Fri, 10 Oct 2025 10:03:25 -0400
Subject: [PATCH] feat: add OpenAI-compatible Bedrock provider with error
 handling

Implements AWS Bedrock inference provider using OpenAI-compatible
endpoint for Llama models available through Bedrock.

Changes:
- Add BedrockInferenceAdapter using OpenAIMixin base
- Configure region-specific endpoint URLs
- Add NotImplementedError stubs for unsupported endpoints
- Implement authentication error handling with helpful messages
- Remove unused models.py file
- Add comprehensive unit tests (12 total)
- Add provider registry configuration
---
 .../providers/inference/remote_bedrock.mdx    |  19 +-
 .../distributions/ci-tests/run.yaml           |   3 +
 .../distributions/starter-gpu/run.yaml        |   3 +
 .../distributions/starter/run.yaml            |   3 +
 .../providers/registry/inference.py           |   5 +-
 .../remote/inference/bedrock/__init__.py      |   2 +-
 .../remote/inference/bedrock/bedrock.py       | 171 +++++++-----------
 .../remote/inference/bedrock/config.py        |  31 +++-
 .../remote/inference/bedrock/models.py        |  29 ---
 .../inference/test_bedrock_adapter.py         |  81 +++++++++
 .../inference/test_bedrock_config.py          |  41 +++++
 tests/unit/providers/test_bedrock.py          |  87 +++++----
 12 files changed, 288 insertions(+), 187 deletions(-)
 delete mode 100644 src/llama_stack/providers/remote/inference/bedrock/models.py
 create mode 100644 tests/unit/providers/inference/test_bedrock_adapter.py
 create mode 100644 tests/unit/providers/inference/test_bedrock_config.py

diff --git a/docs/docs/providers/inference/remote_bedrock.mdx b/docs/docs/providers/inference/remote_bedrock.mdx
index 683ec12f8..c6804e9c5 100644
--- a/docs/docs/providers/inference/remote_bedrock.mdx
+++ b/docs/docs/providers/inference/remote_bedrock.mdx
@@ -1,5 +1,5 @@
 ---
-description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
+description: "AWS Bedrock inference provider using OpenAI-compatible endpoint."
 sidebar_label: Remote - Bedrock
 title: remote::bedrock
 ---
@@ -8,7 +8,7 @@ title: remote::bedrock
 
 ## Description
 
-AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
+AWS Bedrock inference provider using OpenAI-compatible endpoint.
 
 ## Configuration
 
@@ -16,19 +16,12 @@ AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
-| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
-| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
-| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2. Default use environment variable: AWS_DEFAULT_REGION |
-| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use. Default use environment variable: AWS_PROFILE |
-| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
-| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform. Default use environment variable: AWS_RETRY_MODE |
-| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
-| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection. The default is 60 seconds. |
-| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
+| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `region_name` | `str` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |
 
 ## Sample Configuration
 
 ```yaml
-{}
+api_key: ${env.AWS_BEDROCK_API_KEY:=}
+region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
 ```
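The new sample configuration maps one-to-one onto a plain OpenAI client call, which is essentially what the adapter below does under the hood. A minimal sketch, not part of this patch, assuming the `openai` package is installed, a Bedrock API key is exported as `AWS_BEDROCK_API_KEY`, and the illustrative model ID is enabled in your AWS account:

```python
# Sketch: exercising Bedrock's OpenAI-compatible endpoint with the stock
# `openai` client. The model ID below is illustrative.
import os

from openai import OpenAI

region = os.getenv("AWS_DEFAULT_REGION", "us-east-2")
client = OpenAI(
    api_key=os.environ["AWS_BEDROCK_API_KEY"],  # a Bedrock API key, not an OpenAI key
    base_url=f"https://bedrock-runtime.{region}.amazonaws.com/openai/v1",
)

response = client.chat.completions.create(
    model="meta.llama3-1-8b-instruct-v1:0",  # any Llama model enabled in the account
    messages=[{"role": "user", "content": "Hello from Bedrock"}],
)
print(response.choices[0].message.content)
```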
diff --git a/src/llama_stack/distributions/ci-tests/run.yaml b/src/llama_stack/distributions/ci-tests/run.yaml
index 702acff8e..1118d2ad1 100644
--- a/src/llama_stack/distributions/ci-tests/run.yaml
+++ b/src/llama_stack/distributions/ci-tests/run.yaml
@@ -46,6 +46,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
diff --git a/src/llama_stack/distributions/starter-gpu/run.yaml b/src/llama_stack/distributions/starter-gpu/run.yaml
index 807f0d678..7149b8659 100644
--- a/src/llama_stack/distributions/starter-gpu/run.yaml
+++ b/src/llama_stack/distributions/starter-gpu/run.yaml
@@ -46,6 +46,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
diff --git a/src/llama_stack/distributions/starter/run.yaml b/src/llama_stack/distributions/starter/run.yaml
index eb4652af0..0ce392810 100644
--- a/src/llama_stack/distributions/starter/run.yaml
+++ b/src/llama_stack/distributions/starter/run.yaml
@@ -46,6 +46,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
diff --git a/src/llama_stack/providers/registry/inference.py b/src/llama_stack/providers/registry/inference.py
index 00967a8ec..1b70182fc 100644
--- a/src/llama_stack/providers/registry/inference.py
+++ b/src/llama_stack/providers/registry/inference.py
@@ -138,10 +138,11 @@ def available_providers() -> list[ProviderSpec]:
         api=Api.inference,
         adapter_type="bedrock",
         provider_type="remote::bedrock",
-        pip_packages=["boto3"],
+        pip_packages=[],
         module="llama_stack.providers.remote.inference.bedrock",
         config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
-        description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
+        provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
+        description="AWS Bedrock inference provider using OpenAI-compatible endpoint.",
     ),
     RemoteProviderSpec(
         api=Api.inference,
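The `provider_data_validator` registered above is what lets a caller supply a per-request key in the `x-llamastack-provider-data` header instead of relying on the static config value. A minimal sketch of that validation step, assuming the header carries JSON and an illustrative payload:

```python
# Sketch: how a per-request key from the x-llamastack-provider-data header is
# validated against the pydantic model defined in config.py below.
import json

from llama_stack.providers.remote.inference.bedrock.config import BedrockProviderDataValidator

header = '{"aws_bedrock_api_key": "key-from-header"}'  # illustrative payload
provider_data = BedrockProviderDataValidator(**json.loads(header))
assert provider_data.aws_bedrock_api_key == "key-from-header"
```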
diff --git a/src/llama_stack/providers/remote/inference/bedrock/__init__.py b/src/llama_stack/providers/remote/inference/bedrock/__init__.py
index 4d98f4999..4b0686b18 100644
--- a/src/llama_stack/providers/remote/inference/bedrock/__init__.py
+++ b/src/llama_stack/providers/remote/inference/bedrock/__init__.py
@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):
 
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
 
-    impl = BedrockInferenceAdapter(config)
+    impl = BedrockInferenceAdapter(config=config)
 
     await impl.initialize()
diff --git a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
index d266f9e6f..a4093aefe 100644
--- a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -4,139 +4,106 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 
-from botocore.client import BaseClient
+from openai import AuthenticationError
 
 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    Inference,
+    Model,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
 )
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
-from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
-from .models import MODEL_ENTRIES
-
-REGION_PREFIX_MAP = {
-    "us": "us.",
-    "eu": "eu.",
-    "ap": "ap.",
-}
+from .config import BedrockConfig
 
 
-def _get_region_prefix(region: str | None) -> str:
-    # AWS requires region prefixes for inference profiles
-    if region is None:
-        return "us."  # default to US when we don't know
+class BedrockInferenceAdapter(OpenAIMixin):
+    """
+    Adapter for AWS Bedrock's OpenAI-compatible API endpoints.
 
-    # Handle case insensitive region matching
-    region_lower = region.lower()
-    for prefix in REGION_PREFIX_MAP:
-        if region_lower.startswith(f"{prefix}-"):
-            return REGION_PREFIX_MAP[prefix]
+    Supports Llama models across regions and GPT-OSS models (us-west-2 only).
 
-    # Fallback to US for anything we don't recognize
-    return "us."
+    Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
+    for dynamic model discovery. Models must be pre-registered in the config.
+    """
 
+    config: BedrockConfig
+    provider_data_api_key_field: str = "aws_bedrock_api_key"
 
-def _to_inference_profile_id(model_id: str, region: str = None) -> str:
-    # Return ARNs unchanged
-    if model_id.startswith("arn:"):
-        return model_id
+    def get_api_key(self) -> str:
+        """Get API key for OpenAI client."""
+        if not self.config.api_key:
+            raise ValueError(
+                "API key is not set. Please provide a valid API key in the "
+                "provider config or via AWS_BEDROCK_API_KEY environment variable."
+            )
+        return self.config.api_key
 
-    # Return inference profile IDs that already have regional prefixes
-    if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
-        return model_id
+    def get_base_url(self) -> str:
+        """Get base URL for OpenAI client."""
+        return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"
 
-    # Default to US East when no region is provided
-    if region is None:
-        region = "us-east-1"
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
+        Returns an empty list since models must be pre-registered in the config.
+        """
+        return []
 
-    return _get_region_prefix(region) + model_id
+    async def register_model(self, model: Model) -> Model:
+        """
+        Register a model with the Bedrock provider.
 
-
-class BedrockInferenceAdapter(
-    ModelRegistryHelper,
-    Inference,
-):
-    def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-        self._config = config
-        self._client = None
-
-    @property
-    def client(self) -> BaseClient:
-        if self._client is None:
-            self._client = create_bedrock_client(self._config)
-        return self._client
-
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        if self._client is not None:
-            self._client.close()
-
-    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
-        bedrock_model = request.model
-
-        sampling_params = request.sampling_params
-        options = get_sampling_strategy_options(sampling_params)
-
-        if sampling_params.max_tokens:
-            options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
-
-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-
-        # Convert foundation model ID to inference profile ID
-        region_name = self.client.meta.region_name
-        inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
-
-        return {
-            "modelId": inference_profile_id,
-            "body": json.dumps(
-                {
-                    "prompt": prompt,
-                    **options,
-                }
-            ),
-        }
+        Bedrock doesn't support dynamic model listing via /v1/models, so we skip
+        the availability check and accept all models registered in the config.
+        """
+        return model
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+        """Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
+        """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
+            "Only /v1/chat/completions is supported. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
+        """Override to enable streaming usage metrics and handle authentication errors."""
+        # Enable streaming usage metrics when telemetry is active
+        if params.stream and get_current_span() is not None:
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
+        # Wrap call in try/except to catch authentication errors
+        try:
+            return await super().openai_chat_completion(params=params)
+        except AuthenticationError as e:
+            raise ValueError(
+                f"AWS Bedrock authentication failed: {e.message}. "
+                "Please check your API key in the provider config or x-llamastack-provider-data header."
+            ) from e
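The `stream_options` merge in `openai_chat_completion` is easy to misread, so here is the same decision logic as a standalone sketch, with plain dicts standing in for the pydantic request model:

```python
# Sketch of the stream_options merge above: request a final usage chunk when
# streaming under an active trace, without clobbering caller-supplied options.
def ensure_include_usage(stream: bool, stream_options: dict | None, telemetry_active: bool) -> dict | None:
    if not (stream and telemetry_active):
        return stream_options  # non-streaming or untraced requests are untouched
    if stream_options is None:
        return {"include_usage": True}
    if "include_usage" not in stream_options:
        return {**stream_options, "include_usage": True}
    return stream_options  # the caller already chose; respect it


assert ensure_include_usage(True, None, True) == {"include_usage": True}
assert ensure_include_usage(True, {"include_usage": False}, True) == {"include_usage": False}
assert ensure_include_usage(False, None, True) is None
```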
diff --git a/src/llama_stack/providers/remote/inference/bedrock/config.py b/src/llama_stack/providers/remote/inference/bedrock/config.py
index 5961a2f15..2b236e902 100644
--- a/src/llama_stack/providers/remote/inference/bedrock/config.py
+++ b/src/llama_stack/providers/remote/inference/bedrock/config.py
@@ -4,8 +4,33 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
+import os
+
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 
 
-class BedrockConfig(BedrockBaseConfig):
-    pass
+class BedrockProviderDataValidator(BaseModel):
+    aws_bedrock_api_key: str | None = Field(
+        default=None,
+        description="API key for Amazon Bedrock",
+    )
+
+
+class BedrockConfig(RemoteInferenceProviderConfig):
+    api_key: str | None = Field(
+        default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
+        description="Amazon Bedrock API key",
+    )
+    region_name: str = Field(
+        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
+        description="AWS Region for the Bedrock Runtime endpoint",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs):
+        return {
+            "api_key": "${env.AWS_BEDROCK_API_KEY:=}",
+            "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
+        }
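Because both fields use `default_factory`, the environment is consulted each time a config is constructed; an explicit value beats the environment, which in turn beats the hard-coded default. A quick sketch of that resolution order, assuming the config module above is importable:

```python
# Sketch: BedrockConfig resolution order (explicit value > environment > default).
import os

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

os.environ.pop("AWS_BEDROCK_API_KEY", None)
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"

assert BedrockConfig().region_name == "eu-west-1"  # read from the environment
assert BedrockConfig(region_name="us-west-2").region_name == "us-west-2"  # explicit wins
assert BedrockConfig().api_key is None  # no env var and no fallback default
```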
" + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" + ) async def openai_completion( self, params: OpenAICompletionRequestWithExtraBody, ) -> OpenAICompletion: - raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") + """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint.""" + raise NotImplementedError( + "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. " + "Only /v1/chat/completions is supported. " + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" + ) async def openai_chat_completion( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") + """Override to enable streaming usage metrics and handle authentication errors.""" + # Enable streaming usage metrics when telemetry is active + if params.stream and get_current_span() is not None: + if params.stream_options is None: + params.stream_options = {"include_usage": True} + elif "include_usage" not in params.stream_options: + params.stream_options = {**params.stream_options, "include_usage": True} + + # Wrap call in try/except to catch authentication errors + try: + return await super().openai_chat_completion(params=params) + except AuthenticationError as e: + raise ValueError( + f"AWS Bedrock authentication failed: {e.message}. " + "Please check your API key in the provider config or x-llamastack-provider-data header." + ) from e diff --git a/src/llama_stack/providers/remote/inference/bedrock/config.py b/src/llama_stack/providers/remote/inference/bedrock/config.py index 5961a2f15..2b236e902 100644 --- a/src/llama_stack/providers/remote/inference/bedrock/config.py +++ b/src/llama_stack/providers/remote/inference/bedrock/config.py @@ -4,8 +4,33 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +import os + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig -class BedrockConfig(BedrockBaseConfig): - pass +class BedrockProviderDataValidator(BaseModel): + aws_bedrock_api_key: str | None = Field( + default=None, + description="API key for Amazon Bedrock", + ) + + +class BedrockConfig(RemoteInferenceProviderConfig): + api_key: str | None = Field( + default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"), + description="Amazon Bedrock API key", + ) + region_name: str = Field( + default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"), + description="AWS Region for the Bedrock Runtime endpoint", + ) + + @classmethod + def sample_run_config(cls, **kwargs): + return { + "api_key": "${env.AWS_BEDROCK_API_KEY:=}", + "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}", + } diff --git a/src/llama_stack/providers/remote/inference/bedrock/models.py b/src/llama_stack/providers/remote/inference/bedrock/models.py deleted file mode 100644 index 17273c122..000000000 --- a/src/llama_stack/providers/remote/inference/bedrock/models.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_entry,
-)
-
-SAFETY_MODELS_ENTRIES = []
-
-
-# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta.llama3-1-8b-instruct-v1:0",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta.llama3-1-70b-instruct-v1:0",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta.llama3-1-405b-instruct-v1:0",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-] + SAFETY_MODELS_ENTRIES
diff --git a/tests/unit/providers/inference/test_bedrock_adapter.py b/tests/unit/providers/inference/test_bedrock_adapter.py
new file mode 100644
index 000000000..2144059d6
--- /dev/null
+++ b/tests/unit/providers/inference/test_bedrock_adapter.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from openai import AuthenticationError
+
+from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
+from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
+
+
+def test_adapter_initialization():
+    config = BedrockConfig(api_key="test-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter.config.api_key == "test-key"
+    assert adapter.config.region_name == "us-east-1"
+
+
+def test_client_url_construction():
+    config = BedrockConfig(api_key="test-key", region_name="us-west-2")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"
+    assert adapter.get_api_key() == "test-key"
+
+
+def test_api_key_from_config():
+    """Test API key is read from config"""
+    config = BedrockConfig(api_key="config-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter.get_api_key() == "config-key"
+
+
+def test_api_key_from_header_overrides_config():
+    """Test API key from request header overrides config via client property"""
+    config = BedrockConfig(api_key="config-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+    adapter.provider_data_api_key_field = "aws_bedrock_api_key"
+    adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bedrock_api_key="header-key"))
+
+    # The client property is where header override happens (in OpenAIMixin)
+    assert adapter.client.api_key == "header-key"
+
+
+async def test_authentication_error_handling():
+    """Test that AuthenticationError from OpenAI client is converted to ValueError with helpful message"""
+    config = BedrockConfig(api_key="invalid-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    # Mock the parent class method to raise AuthenticationError
+    mock_response = MagicMock()
+    mock_response.message = "Invalid authentication credentials"
+    auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None)
+
+    # Create a mock that raises the error
+    mock_super = AsyncMock(side_effect=auth_error)
+
+    # Patch the parent class method
+    original_method = BedrockInferenceAdapter.__bases__[0].openai_chat_completion
+    BedrockInferenceAdapter.__bases__[0].openai_chat_completion = mock_super
+
+    try:
+        with pytest.raises(ValueError) as exc_info:
+            params = OpenAIChatCompletionRequestWithExtraBody(
+                model="test-model", messages=[{"role": "user", "content": "test"}]
+            )
+            await adapter.openai_chat_completion(params=params)
+
+        assert "AWS Bedrock authentication failed" in str(exc_info.value)
+        assert "Please check your API key" in str(exc_info.value)
+    finally:
+        # Restore original method
+        BedrockInferenceAdapter.__bases__[0].openai_chat_completion = original_method
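The save/patch/restore dance around `BedrockInferenceAdapter.__bases__[0]` above works, but `unittest.mock.patch.object` on `OpenAIMixin` expresses the same test with automatic cleanup. An alternative sketch, not part of this patch, where `adapter`, `auth_error`, and `params` are assumed to be fixtures built as in the test above:

```python
# Sketch: equivalent to test_authentication_error_handling, using patch.object
# for automatic restoration instead of manually swapping the base-class method.
from unittest.mock import AsyncMock, patch

import pytest

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


async def test_authentication_error_handling_alt(adapter, auth_error, params):
    with patch.object(OpenAIMixin, "openai_chat_completion", AsyncMock(side_effect=auth_error)):
        with pytest.raises(ValueError, match="AWS Bedrock authentication failed"):
            await adapter.openai_chat_completion(params=params)
```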
diff --git a/tests/unit/providers/inference/test_bedrock_config.py b/tests/unit/providers/inference/test_bedrock_config.py
new file mode 100644
index 000000000..4d97900e4
--- /dev/null
+++ b/tests/unit/providers/inference/test_bedrock_config.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
+
+
+def test_bedrock_config_defaults_no_env(monkeypatch):
+    """Test BedrockConfig defaults when env vars are not set"""
+    monkeypatch.delenv("AWS_BEDROCK_API_KEY", raising=False)
+    monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
+    config = BedrockConfig()
+    assert config.api_key is None
+    assert config.region_name == "us-east-2"
+
+
+def test_bedrock_config_defaults_with_env(monkeypatch):
+    """Test BedrockConfig reads from environment variables"""
+    monkeypatch.setenv("AWS_BEDROCK_API_KEY", "env-key")
+    monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
+    config = BedrockConfig()
+    assert config.api_key == "env-key"
+    assert config.region_name == "eu-west-1"
+
+
+def test_bedrock_config_with_values():
+    """Test BedrockConfig accepts explicit values"""
+    config = BedrockConfig(api_key="test-key", region_name="us-west-2")
+    assert config.api_key == "test-key"
+    assert config.region_name == "us-west-2"
+
+
+def test_bedrock_config_sample():
+    """Test BedrockConfig sample_run_config returns correct format"""
+    sample = BedrockConfig.sample_run_config()
+    assert "api_key" in sample
+    assert "region_name" in sample
+    assert sample["api_key"] == "${env.AWS_BEDROCK_API_KEY:=}"
+    assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}"
diff --git a/tests/unit/providers/test_bedrock.py b/tests/unit/providers/test_bedrock.py
index 1ff07bbbe..18ab91258 100644
--- a/tests/unit/providers/test_bedrock.py
+++ b/tests/unit/providers/test_bedrock.py
@@ -4,50 +4,63 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.remote.inference.bedrock.bedrock import (
-    _get_region_prefix,
-    _to_inference_profile_id,
-)
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, PropertyMock, patch
+
+from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
+from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 
 
-def test_region_prefixes():
-    assert _get_region_prefix("us-east-1") == "us."
-    assert _get_region_prefix("eu-west-1") == "eu."
-    assert _get_region_prefix("ap-south-1") == "ap."
-    assert _get_region_prefix("ca-central-1") == "us."
+def test_can_create_adapter(): + config = BedrockConfig(api_key="test-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) - # Test case insensitive - assert _get_region_prefix("US-EAST-1") == "us." - assert _get_region_prefix("EU-WEST-1") == "eu." - assert _get_region_prefix("Ap-South-1") == "ap." - - # Test None region - assert _get_region_prefix(None) == "us." + assert adapter is not None + assert adapter.config.region_name == "us-east-1" + assert adapter.get_api_key() == "test-key" -def test_model_id_conversion(): - # Basic conversion - assert ( - _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0" +def test_different_aws_regions(): + # just check a couple regions to verify URL construction works + config = BedrockConfig(api_key="key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + assert adapter.get_base_url() == "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1" + + config = BedrockConfig(api_key="key", region_name="eu-west-1") + adapter = BedrockInferenceAdapter(config=config) + assert adapter.get_base_url() == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1" + + +async def test_basic_chat_completion(): + """Test basic chat completion works with OpenAIMixin""" + config = BedrockConfig(api_key="k", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + + class FakeModelStore: + async def get_model(self, model_id): + return SimpleNamespace(provider_resource_id="meta.llama3-1-8b-instruct-v1:0") + + adapter.model_store = FakeModelStore() + + fake_response = SimpleNamespace( + id="chatcmpl-123", + choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", role="assistant"), finish_reason="stop")], ) - # Already has prefix - assert ( - _to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1") - == "us.meta.llama3-1-70b-instruct-v1:0" - ) + mock_create = AsyncMock(return_value=fake_response) - # ARN should be returned unchanged - arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0" - assert _to_inference_profile_id(arn, "us-east-1") == arn + class FakeClient: + def __init__(self): + self.chat = SimpleNamespace(completions=SimpleNamespace(create=mock_create)) - # ARN should be returned unchanged even without region - assert _to_inference_profile_id(arn) == arn + with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=FakeClient()): + params = OpenAIChatCompletionRequestWithExtraBody( + model="llama3-1-8b", + messages=[{"role": "user", "content": "hello"}], + stream=False, + ) + response = await adapter.openai_chat_completion(params=params) - # Optional region parameter defaults to us-east-1 - assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0" - - # Different regions work with optional parameter - assert ( - _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0" - ) + assert response.id == "chatcmpl-123" + assert mock_create.await_count == 1
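The three suites together account for the 12 tests claimed in the commit message (5 adapter, 4 config, 3 in test_bedrock.py). A sketch of a combined run, assuming pytest is configured with asyncio auto mode so the bare `async def` tests are collected:

```python
# Sketch: run the suites touched by this patch (CI normally invokes pytest directly).
import pytest

raise SystemExit(
    pytest.main(
        [
            "tests/unit/providers/inference/test_bedrock_adapter.py",
            "tests/unit/providers/inference/test_bedrock_config.py",
            "tests/unit/providers/test_bedrock.py",
        ]
    )
)
```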