From e894e36eea5c49db88a241ac95ef24f1aa7183fc Mon Sep 17 00:00:00 2001 From: Sumanth Kamenani Date: Thu, 6 Nov 2025 20:18:18 -0500 Subject: [PATCH] feat: add OpenAI-compatible Bedrock provider (#3748) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements AWS Bedrock inference provider using OpenAI-compatible endpoint for Llama models available through Bedrock. Closes: #3410 ## What does this PR do? Adds AWS Bedrock as an inference provider using the OpenAI-compatible endpoint. This lets us use Bedrock models (GPT-OSS, Llama) through the standard llama-stack inference API. The implementation uses LiteLLM's OpenAI client under the hood, so it gets all the OpenAI compatibility features. The provider handles per-request API key overrides via headers. ## Test Plan **Tested the following scenarios:** - Non-streaming completion - basic request/response flow - Streaming completion - SSE streaming with chunked responses - Multi-turn conversations - context retention across turns - Tool calling - function calling with proper tool_calls format # Bedrock OpenAI-Compatible Provider - Test Results **Model:** `bedrock-inference/openai.gpt-oss-20b-1:0` --- ## Test 1: Model Listing **Request:** ```http GET /v1/models HTTP/1.1 ``` **Response:** ```http HTTP/1.1 200 OK Content-Type: application/json { "data": [ {"identifier": "bedrock-inference/openai.gpt-oss-20b-1:0", ...}, {"identifier": "bedrock-inference/openai.gpt-oss-40b-1:0", ...} ] } ``` --- ## Test 2: Non-Streaming Completion **Request:** ```http POST /v1/chat/completions HTTP/1.1 Content-Type: application/json { "model": "bedrock-inference/openai.gpt-oss-20b-1:0", "messages": [{"role": "user", "content": "Say 'Hello from Bedrock' and nothing else"}], "stream": false } ``` **Response:** ```http HTTP/1.1 200 OK Content-Type: application/json { "choices": [{ "finish_reason": "stop", "message": {"content": "...Hello from Bedrock"} }], "usage": {"prompt_tokens": 79, "completion_tokens": 50, "total_tokens": 129} } ``` --- ## Test 3: Streaming Completion **Request:** ```http POST /v1/chat/completions HTTP/1.1 Content-Type: application/json { "model": "bedrock-inference/openai.gpt-oss-20b-1:0", "messages": [{"role": "user", "content": "Count from 1 to 5"}], "stream": true } ``` **Response:** ```http HTTP/1.1 200 OK Content-Type: text/event-stream [6 SSE chunks received] Final content: "1, 2, 3, 4, 5" ``` --- ## Test 4: Error Handling - Invalid Model **Request:** ```http POST /v1/chat/completions HTTP/1.1 Content-Type: application/json { "model": "invalid-model-id", "messages": [{"role": "user", "content": "Hello"}], "stream": false } ``` **Response:** ```http HTTP/1.1 404 Not Found Content-Type: application/json { "detail": "Model 'invalid-model-id' not found. Use 'client.models.list()' to list available Models." } ``` --- ## Test 5: Multi-Turn Conversation **Request 1:** ```http POST /v1/chat/completions HTTP/1.1 { "messages": [{"role": "user", "content": "My name is Alice"}] } ``` **Response 1:** ```http HTTP/1.1 200 OK { "choices": [{ "message": {"content": "...Nice to meet you, Alice! 
How can I help you today?"} }] } ``` **Request 2 (with history):** ```http POST /v1/chat/completions HTTP/1.1 { "messages": [ {"role": "user", "content": "My name is Alice"}, {"role": "assistant", "content": "...Nice to meet you, Alice!..."}, {"role": "user", "content": "What is my name?"} ] } ``` **Response 2:** ```http HTTP/1.1 200 OK { "choices": [{ "message": {"content": "...Your name is Alice."} }], "usage": {"prompt_tokens": 183, "completion_tokens": 42} } ``` **Context retained across turns** --- ## Test 6: System Messages **Request:** ```http POST /v1/chat/completions HTTP/1.1 { "messages": [ {"role": "system", "content": "You are Shakespeare. Respond only in Shakespearean English."}, {"role": "user", "content": "Tell me about the weather"} ] } ``` **Response:** ```http HTTP/1.1 200 OK { "choices": [{ "message": {"content": "Lo! I heed thy request..."} }], "usage": {"completion_tokens": 813} } ``` --- ## Test 7: Tool Calling **Request:** ```http POST /v1/chat/completions HTTP/1.1 { "messages": [{"role": "user", "content": "What's the weather in San Francisco?"}], "tools": [{ "type": "function", "function": { "name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}} } }] } ``` **Response:** ```http HTTP/1.1 200 OK { "choices": [{ "finish_reason": "tool_calls", "message": { "tool_calls": [{ "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"} }] } }] } ``` --- ## Test 8: Sampling Parameters **Request:** ```http POST /v1/chat/completions HTTP/1.1 { "messages": [{"role": "user", "content": "Say hello"}], "temperature": 0.7, "top_p": 0.9 } ``` **Response:** ```http HTTP/1.1 200 OK { "choices": [{ "message": {"content": "...Hello! 👋 How can I help you today?"} }] } ``` --- ## Test 9: Authentication Error Handling ### Subtest A: Invalid API Key **Request:** ```http POST /v1/chat/completions HTTP/1.1 x-llamastack-provider-data: {"aws_bedrock_api_key": "invalid-fake-key-12345"} {"model": "bedrock-inference/openai.gpt-oss-20b-1:0", ...} ``` **Response:** ```http HTTP/1.1 400 Bad Request { "detail": "Invalid value: Authentication failed: Error code: 401 - {'error': {'message': 'Invalid API Key format: Must start with pre-defined prefix', ...}}" } ``` --- ### Subtest B: Empty API Key (Fallback to Config) **Request:** ```http POST /v1/chat/completions HTTP/1.1 x-llamastack-provider-data: {"aws_bedrock_api_key": ""} {"model": "bedrock-inference/openai.gpt-oss-20b-1:0", ...} ``` **Response:** ```http HTTP/1.1 200 OK { "choices": [{ "message": {"content": "...Hello! 
How can I assist you today?"} }] } ``` **Fell back to config key** --- ### Subtest C: Malformed Token **Request:** ```http POST /v1/chat/completions HTTP/1.1 x-llamastack-provider-data: {"aws_bedrock_api_key": "not-a-valid-bedrock-token-format"} {"model": "bedrock-inference/openai.gpt-oss-20b-1:0", ...} ``` **Response:** ```http HTTP/1.1 400 Bad Request { "detail": "Invalid value: Authentication failed: Error code: 401 - {'error': {'message': 'Invalid API Key format: Must start with pre-defined prefix', ...}}" } ``` --- .../providers/inference/remote_bedrock.mdx | 19 +- src/llama_stack/core/routers/inference.py | 4 +- .../distributions/ci-tests/run.yaml | 3 + .../starter-gpu/run-with-postgres-store.yaml | 3 + .../distributions/starter-gpu/run.yaml | 3 + .../starter/run-with-postgres-store.yaml | 3 + .../distributions/starter/run.yaml | 3 + .../providers/registry/inference.py | 5 +- .../remote/inference/bedrock/__init__.py | 2 +- .../remote/inference/bedrock/bedrock.py | 191 ++++++++---------- .../remote/inference/bedrock/config.py | 27 ++- .../remote/inference/bedrock/models.py | 29 --- .../inference/test_bedrock_adapter.py | 78 +++++++ .../inference/test_bedrock_config.py | 39 ++++ tests/unit/providers/test_bedrock.py | 90 +++++---- 15 files changed, 309 insertions(+), 190 deletions(-) delete mode 100644 src/llama_stack/providers/remote/inference/bedrock/models.py create mode 100644 tests/unit/providers/inference/test_bedrock_adapter.py create mode 100644 tests/unit/providers/inference/test_bedrock_config.py diff --git a/docs/docs/providers/inference/remote_bedrock.mdx b/docs/docs/providers/inference/remote_bedrock.mdx index 683ec12f8..61931643e 100644 --- a/docs/docs/providers/inference/remote_bedrock.mdx +++ b/docs/docs/providers/inference/remote_bedrock.mdx @@ -1,5 +1,5 @@ --- -description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service." +description: "AWS Bedrock inference provider using OpenAI compatible endpoint." sidebar_label: Remote - Bedrock title: remote::bedrock --- @@ -8,7 +8,7 @@ title: remote::bedrock ## Description -AWS Bedrock inference provider for accessing various AI models through AWS's managed service. +AWS Bedrock inference provider using OpenAI compatible endpoint. ## Configuration @@ -16,19 +16,12 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID | -| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY | -| `aws_session_token` | `str \| None` | No | | The AWS session token to use. 
Default use environment variable: AWS_SESSION_TOKEN | -| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION | -| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | -| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | -| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | -| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | -| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | -| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | +| `region_name` | `` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint | ## Sample Configuration ```yaml -{} +api_key: ${env.AWS_BEDROCK_API_KEY:=} +region_name: ${env.AWS_DEFAULT_REGION:=us-east-2} ``` diff --git a/src/llama_stack/core/routers/inference.py b/src/llama_stack/core/routers/inference.py index a4f0f4411..d6270d428 100644 --- a/src/llama_stack/core/routers/inference.py +++ b/src/llama_stack/core/routers/inference.py @@ -190,7 +190,7 @@ class InferenceRouter(Inference): response = await provider.openai_completion(params) response.model = request_model_id - if self.telemetry_enabled: + if self.telemetry_enabled and response.usage is not None: metrics = self._construct_metrics( prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, @@ -253,7 +253,7 @@ class InferenceRouter(Inference): if self.store: asyncio.create_task(self.store.store_chat_completion(response, params.messages)) - if self.telemetry_enabled: + if self.telemetry_enabled and response.usage is not None: metrics = self._construct_metrics( prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, diff --git a/src/llama_stack/distributions/ci-tests/run.yaml b/src/llama_stack/distributions/ci-tests/run.yaml index 702acff8e..1118d2ad1 100644 --- a/src/llama_stack/distributions/ci-tests/run.yaml +++ b/src/llama_stack/distributions/ci-tests/run.yaml @@ -46,6 +46,9 @@ providers: api_key: ${env.TOGETHER_API_KEY:=} - provider_id: bedrock provider_type: remote::bedrock + config: + api_key: ${env.AWS_BEDROCK_API_KEY:=} + region_name: ${env.AWS_DEFAULT_REGION:=us-east-2} - provider_id: ${env.NVIDIA_API_KEY:+nvidia} provider_type: remote::nvidia config: diff --git a/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml b/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml index 6dbbc8716..1920ebd9d 100644 --- a/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +++ b/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml @@ -46,6 +46,9 @@ providers: api_key: ${env.TOGETHER_API_KEY:=} - provider_id: bedrock provider_type: remote::bedrock + config: + api_key: 
${env.AWS_BEDROCK_API_KEY:=} + region_name: ${env.AWS_DEFAULT_REGION:=us-east-2} - provider_id: ${env.NVIDIA_API_KEY:+nvidia} provider_type: remote::nvidia config: diff --git a/src/llama_stack/distributions/starter-gpu/run.yaml b/src/llama_stack/distributions/starter-gpu/run.yaml index 807f0d678..7149b8659 100644 --- a/src/llama_stack/distributions/starter-gpu/run.yaml +++ b/src/llama_stack/distributions/starter-gpu/run.yaml @@ -46,6 +46,9 @@ providers: api_key: ${env.TOGETHER_API_KEY:=} - provider_id: bedrock provider_type: remote::bedrock + config: + api_key: ${env.AWS_BEDROCK_API_KEY:=} + region_name: ${env.AWS_DEFAULT_REGION:=us-east-2} - provider_id: ${env.NVIDIA_API_KEY:+nvidia} provider_type: remote::nvidia config: diff --git a/src/llama_stack/distributions/starter/run-with-postgres-store.yaml b/src/llama_stack/distributions/starter/run-with-postgres-store.yaml index 530084bd9..702f95381 100644 --- a/src/llama_stack/distributions/starter/run-with-postgres-store.yaml +++ b/src/llama_stack/distributions/starter/run-with-postgres-store.yaml @@ -46,6 +46,9 @@ providers: api_key: ${env.TOGETHER_API_KEY:=} - provider_id: bedrock provider_type: remote::bedrock + config: + api_key: ${env.AWS_BEDROCK_API_KEY:=} + region_name: ${env.AWS_DEFAULT_REGION:=us-east-2} - provider_id: ${env.NVIDIA_API_KEY:+nvidia} provider_type: remote::nvidia config: diff --git a/src/llama_stack/distributions/starter/run.yaml b/src/llama_stack/distributions/starter/run.yaml index eb4652af0..0ce392810 100644 --- a/src/llama_stack/distributions/starter/run.yaml +++ b/src/llama_stack/distributions/starter/run.yaml @@ -46,6 +46,9 @@ providers: api_key: ${env.TOGETHER_API_KEY:=} - provider_id: bedrock provider_type: remote::bedrock + config: + api_key: ${env.AWS_BEDROCK_API_KEY:=} + region_name: ${env.AWS_DEFAULT_REGION:=us-east-2} - provider_id: ${env.NVIDIA_API_KEY:+nvidia} provider_type: remote::nvidia config: diff --git a/src/llama_stack/providers/registry/inference.py b/src/llama_stack/providers/registry/inference.py index 00967a8ec..1b70182fc 100644 --- a/src/llama_stack/providers/registry/inference.py +++ b/src/llama_stack/providers/registry/inference.py @@ -138,10 +138,11 @@ def available_providers() -> list[ProviderSpec]: api=Api.inference, adapter_type="bedrock", provider_type="remote::bedrock", - pip_packages=["boto3"], + pip_packages=[], module="llama_stack.providers.remote.inference.bedrock", config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", - description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", + provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator", + description="AWS Bedrock inference provider using OpenAI compatible endpoint.", ), RemoteProviderSpec( api=Api.inference, diff --git a/src/llama_stack/providers/remote/inference/bedrock/__init__.py b/src/llama_stack/providers/remote/inference/bedrock/__init__.py index 4d98f4999..4b0686b18 100644 --- a/src/llama_stack/providers/remote/inference/bedrock/__init__.py +++ b/src/llama_stack/providers/remote/inference/bedrock/__init__.py @@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps): assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}" - impl = BedrockInferenceAdapter(config) + impl = BedrockInferenceAdapter(config=config) await impl.initialize() diff --git a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py 
b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py index d266f9e6f..1bf44b51a 100644 --- a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -4,139 +4,124 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable -from botocore.client import BaseClient +from openai import AuthenticationError from llama_stack.apis.inference import ( - ChatCompletionRequest, - Inference, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAIChatCompletionRequestWithExtraBody, + OpenAICompletion, OpenAICompletionRequestWithExtraBody, OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsResponse, ) -from llama_stack.apis.inference.inference import ( - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, -) -from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig -from llama_stack.providers.utils.bedrock.client import create_bedrock_client -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) -from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_strategy_options, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_prompt, -) +from llama_stack.core.telemetry.tracing import get_current_span +from llama_stack.log import get_logger +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin -from .models import MODEL_ENTRIES +from .config import BedrockConfig -REGION_PREFIX_MAP = { - "us": "us.", - "eu": "eu.", - "ap": "ap.", -} +logger = get_logger(name=__name__, category="inference::bedrock") -def _get_region_prefix(region: str | None) -> str: - # AWS requires region prefixes for inference profiles - if region is None: - return "us." # default to US when we don't know +class BedrockInferenceAdapter(OpenAIMixin): + """ + Adapter for AWS Bedrock's OpenAI-compatible API endpoints. - # Handle case insensitive region matching - region_lower = region.lower() - for prefix in REGION_PREFIX_MAP: - if region_lower.startswith(f"{prefix}-"): - return REGION_PREFIX_MAP[prefix] + Supports Llama models across regions and GPT-OSS models (us-west-2 only). - # Fallback to US for anything we don't recognize - return "us." + Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models + for dynamic model discovery. Models must be pre-registered in the config. + """ + config: BedrockConfig + provider_data_api_key_field: str = "aws_bedrock_api_key" -def _to_inference_profile_id(model_id: str, region: str = None) -> str: - # Return ARNs unchanged - if model_id.startswith("arn:"): - return model_id + def get_base_url(self) -> str: + """Get base URL for OpenAI client.""" + return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1" - # Return inference profile IDs that already have regional prefixes - if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()): - return model_id + async def list_provider_model_ids(self) -> Iterable[str]: + """ + Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint. + Returns empty list since models must be pre-registered in the config. 
+ """ + return [] - # Default to US East when no region is provided - if region is None: - region = "us-east-1" - - return _get_region_prefix(region) + model_id - - -class BedrockInferenceAdapter( - ModelRegistryHelper, - Inference, -): - def __init__(self, config: BedrockConfig) -> None: - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) - self._config = config - self._client = None - - @property - def client(self) -> BaseClient: - if self._client is None: - self._client = create_bedrock_client(self._config) - return self._client - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - if self._client is not None: - self._client.close() - - async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict: - bedrock_model = request.model - - sampling_params = request.sampling_params - options = get_sampling_strategy_options(sampling_params) - - if sampling_params.max_tokens: - options["max_gen_len"] = sampling_params.max_tokens - if sampling_params.repetition_penalty > 0: - options["repetition_penalty"] = sampling_params.repetition_penalty - - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) - - # Convert foundation model ID to inference profile ID - region_name = self.client.meta.region_name - inference_profile_id = _to_inference_profile_id(bedrock_model, region_name) - - return { - "modelId": inference_profile_id, - "body": json.dumps( - { - "prompt": prompt, - **options, - } - ), - } + async def check_model_availability(self, model: str) -> bool: + """ + Bedrock doesn't support dynamic model listing via /v1/models. + Always return True to accept all models registered in the config. + """ + return True async def openai_embeddings( self, params: OpenAIEmbeddingsRequestWithExtraBody, ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() + """Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint.""" + raise NotImplementedError( + "Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. " + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" + ) async def openai_completion( self, params: OpenAICompletionRequestWithExtraBody, ) -> OpenAICompletion: - raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") + """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint.""" + raise NotImplementedError( + "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. " + "Only /v1/chat/completions is supported. 
" + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" + ) async def openai_chat_completion( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") + """Override to enable streaming usage metrics and handle authentication errors.""" + # Enable streaming usage metrics when telemetry is active + if params.stream and get_current_span() is not None: + if params.stream_options is None: + params.stream_options = {"include_usage": True} + elif "include_usage" not in params.stream_options: + params.stream_options = {**params.stream_options, "include_usage": True} + + try: + logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}") + result = await super().openai_chat_completion(params=params) + logger.debug(f"Bedrock API returned: {type(result).__name__ if result is not None else 'None'}") + + if result is None: + logger.error(f"Bedrock OpenAI client returned None for model={params.model}, stream={params.stream}") + raise RuntimeError( + f"Bedrock API returned no response for model '{params.model}'. " + "This may indicate the model is not supported or a network/API issue occurred." + ) + + return result + except AuthenticationError as e: + error_msg = str(e) + + # Check if this is a token expiration error + if "expired" in error_msg.lower() or "Bearer Token has expired" in error_msg: + logger.error(f"AWS Bedrock authentication token expired: {error_msg}") + raise ValueError( + "AWS Bedrock authentication failed: Bearer token has expired. " + "The AWS_BEDROCK_API_KEY environment variable contains an expired pre-signed URL. " + "Please refresh your token by generating a new pre-signed URL with AWS credentials. " + "Refer to AWS Bedrock documentation for details on OpenAI-compatible endpoints." + ) from e + else: + logger.error(f"AWS Bedrock authentication failed: {error_msg}") + raise ValueError( + f"AWS Bedrock authentication failed: {error_msg}. " + "Please verify your API key is correct in the provider config or x-llamastack-provider-data header. " + "The API key should be a valid AWS pre-signed URL for Bedrock's OpenAI-compatible endpoint." + ) from e + except Exception as e: + logger.error(f"Unexpected error calling Bedrock API: {type(e).__name__}: {e}", exc_info=True) + raise diff --git a/src/llama_stack/providers/remote/inference/bedrock/config.py b/src/llama_stack/providers/remote/inference/bedrock/config.py index 5961a2f15..631a6e7ef 100644 --- a/src/llama_stack/providers/remote/inference/bedrock/config.py +++ b/src/llama_stack/providers/remote/inference/bedrock/config.py @@ -4,8 +4,29 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +import os + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig -class BedrockConfig(BedrockBaseConfig): - pass +class BedrockProviderDataValidator(BaseModel): + aws_bedrock_api_key: str | None = Field( + default=None, + description="API key for Amazon Bedrock", + ) + + +class BedrockConfig(RemoteInferenceProviderConfig): + region_name: str = Field( + default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"), + description="AWS Region for the Bedrock Runtime endpoint", + ) + + @classmethod + def sample_run_config(cls, **kwargs): + return { + "api_key": "${env.AWS_BEDROCK_API_KEY:=}", + "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}", + } diff --git a/src/llama_stack/providers/remote/inference/bedrock/models.py b/src/llama_stack/providers/remote/inference/bedrock/models.py deleted file mode 100644 index 17273c122..000000000 --- a/src/llama_stack/providers/remote/inference/bedrock/models.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - build_hf_repo_model_entry, -) - -SAFETY_MODELS_ENTRIES = [] - - -# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "meta.llama3-1-8b-instruct-v1:0", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "meta.llama3-1-70b-instruct-v1:0", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta.llama3-1-405b-instruct-v1:0", - CoreModelId.llama3_1_405b_instruct.value, - ), -] + SAFETY_MODELS_ENTRIES diff --git a/tests/unit/providers/inference/test_bedrock_adapter.py b/tests/unit/providers/inference/test_bedrock_adapter.py new file mode 100644 index 000000000..fdd07c032 --- /dev/null +++ b/tests/unit/providers/inference/test_bedrock_adapter.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest +from openai import AuthenticationError + +from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody +from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig + + +def test_adapter_initialization(): + config = BedrockConfig(api_key="test-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + + assert adapter.config.auth_credential.get_secret_value() == "test-key" + assert adapter.config.region_name == "us-east-1" + + +def test_client_url_construction(): + config = BedrockConfig(api_key="test-key", region_name="us-west-2") + adapter = BedrockInferenceAdapter(config=config) + + assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1" + + +def test_api_key_from_config(): + config = BedrockConfig(api_key="config-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + assert adapter.config.auth_credential.get_secret_value() == "config-key" + + +def test_api_key_from_header_overrides_config(): + """Test API key from request header overrides config via client property""" + config = BedrockConfig(api_key="config-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + adapter.provider_data_api_key_field = "aws_bedrock_api_key" + adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bedrock_api_key="header-key")) + + # The client property is where header override happens (in OpenAIMixin) + assert adapter.client.api_key == "header-key" + + +async def test_authentication_error_handling(): + """Test that AuthenticationError from OpenAI client is converted to ValueError with helpful message""" + config = BedrockConfig(api_key="invalid-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + + # Mock the parent class method to raise AuthenticationError + mock_response = MagicMock() + mock_response.message = "Invalid authentication credentials" + auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None) + + # Create a mock that raises the error + mock_super = AsyncMock(side_effect=auth_error) + + # Patch the parent class method + original_method = BedrockInferenceAdapter.__bases__[0].openai_chat_completion + BedrockInferenceAdapter.__bases__[0].openai_chat_completion = mock_super + + try: + with pytest.raises(ValueError) as exc_info: + params = OpenAIChatCompletionRequestWithExtraBody( + model="test-model", messages=[{"role": "user", "content": "test"}] + ) + await adapter.openai_chat_completion(params=params) + + assert "AWS Bedrock authentication failed" in str(exc_info.value) + assert "Please verify your API key" in str(exc_info.value) + finally: + # Restore original method + BedrockInferenceAdapter.__bases__[0].openai_chat_completion = original_method diff --git a/tests/unit/providers/inference/test_bedrock_config.py b/tests/unit/providers/inference/test_bedrock_config.py new file mode 100644 index 000000000..4c1fd56a2 --- /dev/null +++ b/tests/unit/providers/inference/test_bedrock_config.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig + + +def test_bedrock_config_defaults_no_env(monkeypatch): + """Test BedrockConfig defaults when env vars are not set""" + monkeypatch.delenv("AWS_BEDROCK_API_KEY", raising=False) + monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False) + config = BedrockConfig() + assert config.auth_credential is None + assert config.region_name == "us-east-2" + + +def test_bedrock_config_reads_from_env(monkeypatch): + """Test BedrockConfig field initialization reads from environment variables""" + monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1") + config = BedrockConfig() + assert config.region_name == "eu-west-1" + + +def test_bedrock_config_with_values(): + """Test BedrockConfig accepts explicit values via alias""" + config = BedrockConfig(api_key="test-key", region_name="us-west-2") + assert config.auth_credential.get_secret_value() == "test-key" + assert config.region_name == "us-west-2" + + +def test_bedrock_config_sample(): + """Test BedrockConfig sample_run_config returns correct format""" + sample = BedrockConfig.sample_run_config() + assert "api_key" in sample + assert "region_name" in sample + assert sample["api_key"] == "${env.AWS_BEDROCK_API_KEY:=}" + assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}" diff --git a/tests/unit/providers/test_bedrock.py b/tests/unit/providers/test_bedrock.py index 1ff07bbbe..684fcf262 100644 --- a/tests/unit/providers/test_bedrock.py +++ b/tests/unit/providers/test_bedrock.py @@ -4,50 +4,66 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.remote.inference.bedrock.bedrock import ( - _get_region_prefix, - _to_inference_profile_id, -) +from types import SimpleNamespace +from unittest.mock import AsyncMock, PropertyMock, patch + +from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody +from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig -def test_region_prefixes(): - assert _get_region_prefix("us-east-1") == "us." - assert _get_region_prefix("eu-west-1") == "eu." - assert _get_region_prefix("ap-south-1") == "ap." - assert _get_region_prefix("ca-central-1") == "us." +def test_can_create_adapter(): + config = BedrockConfig(api_key="test-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) - # Test case insensitive - assert _get_region_prefix("US-EAST-1") == "us." - assert _get_region_prefix("EU-WEST-1") == "eu." - assert _get_region_prefix("Ap-South-1") == "ap." - - # Test None region - assert _get_region_prefix(None) == "us." 
+ assert adapter is not None + assert adapter.config.region_name == "us-east-1" + assert adapter.get_api_key() == "test-key" -def test_model_id_conversion(): - # Basic conversion - assert ( - _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0" +def test_different_aws_regions(): + # just check a couple regions to verify URL construction works + config = BedrockConfig(api_key="key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + assert adapter.get_base_url() == "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1" + + config = BedrockConfig(api_key="key", region_name="eu-west-1") + adapter = BedrockInferenceAdapter(config=config) + assert adapter.get_base_url() == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1" + + +async def test_basic_chat_completion(): + """Test basic chat completion works with OpenAIMixin""" + config = BedrockConfig(api_key="k", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + + class FakeModelStore: + async def has_model(self, model_id): + return True + + async def get_model(self, model_id): + return SimpleNamespace(provider_resource_id="meta.llama3-1-8b-instruct-v1:0") + + adapter.model_store = FakeModelStore() + + fake_response = SimpleNamespace( + id="chatcmpl-123", + choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", role="assistant"), finish_reason="stop")], ) - # Already has prefix - assert ( - _to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1") - == "us.meta.llama3-1-70b-instruct-v1:0" - ) + mock_create = AsyncMock(return_value=fake_response) - # ARN should be returned unchanged - arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0" - assert _to_inference_profile_id(arn, "us-east-1") == arn + class FakeClient: + def __init__(self): + self.chat = SimpleNamespace(completions=SimpleNamespace(create=mock_create)) - # ARN should be returned unchanged even without region - assert _to_inference_profile_id(arn) == arn + with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=FakeClient()): + params = OpenAIChatCompletionRequestWithExtraBody( + model="llama3-1-8b", + messages=[{"role": "user", "content": "hello"}], + stream=False, + ) + response = await adapter.openai_chat_completion(params=params) - # Optional region parameter defaults to us-east-1 - assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0" - - # Different regions work with optional parameter - assert ( - _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0" - ) + assert response.id == "chatcmpl-123" + assert mock_create.await_count == 1