Mirror of https://github.com/meta-llama/llama-stack.git
feat: add OpenAI-compatible Bedrock provider with error handling
Implements an AWS Bedrock inference provider using the OpenAI-compatible endpoint for Llama models available through Bedrock.

Changes:
- Add BedrockInferenceAdapter using the OpenAIMixin base
- Configure region-specific endpoint URLs
- Add NotImplementedError stubs for unsupported endpoints
- Implement authentication error handling with helpful messages
- Remove unused models.py file
- Add comprehensive unit tests (12 total)
- Add provider registry configuration
Parent: c899b50723
Commit: 4ff367251f
12 changed files with 288 additions and 187 deletions
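Reviewer note: "OpenAI-compatible endpoint" here means the adapter simply points a standard OpenAI client at Bedrock's runtime URL. A minimal sketch of the equivalent direct call; the URL format and the AWS_BEDROCK_API_KEY variable come from this commit, while the region and model ID are illustrative assumptions:

```python
# Illustrative only: call Bedrock's OpenAI-compatible endpoint directly with the
# openai SDK, using the same base URL shape the adapter constructs below.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://bedrock-runtime.us-east-2.amazonaws.com/openai/v1",  # region is an assumption
    api_key=os.environ["AWS_BEDROCK_API_KEY"],
)
response = client.chat.completions.create(
    model="meta.llama3-1-8b-instruct-v1:0",  # model ID taken from the removed models.py
    messages=[{"role": "user", "content": "Hello from Bedrock"}],
)
print(response.choices[0].message.content)
```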
@@ -1,5 +1,5 @@
 ---
-description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
+description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
 sidebar_label: Remote - Bedrock
 title: remote::bedrock
 ---
@@ -8,7 +8,7 @@ title: remote::bedrock
 
 ## Description
 
-AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
+AWS Bedrock inference provider using OpenAI compatible endpoint.
 
 ## Configuration
 
@@ -16,19 +16,12 @@ AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Defaults to environment variable: AWS_ACCESS_KEY_ID |
-| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Defaults to environment variable: AWS_SECRET_ACCESS_KEY |
-| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Defaults to environment variable: AWS_SESSION_TOKEN |
-| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2. Defaults to environment variable: AWS_DEFAULT_REGION |
-| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use. Defaults to environment variable: AWS_PROFILE |
-| `total_max_attempts` | `int \| None` | No | | The maximum number of attempts that will be made for a single request, including the initial attempt. Defaults to environment variable: AWS_MAX_ATTEMPTS |
-| `retry_mode` | `str \| None` | No | | The type of retries Boto3 will perform. Defaults to environment variable: AWS_RETRY_MODE |
-| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds until a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
-| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds until a timeout exception is thrown when attempting to read from a connection. The default is 60 seconds. |
-| `session_ttl` | `int \| None` | No | 3600 | The time in seconds until a session expires. The default is 3600 seconds (1 hour). |
+| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `region_name` | `str` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |
 
 ## Sample Configuration
 
 ```yaml
-{}
+api_key: ${env.AWS_BEDROCK_API_KEY:=}
+region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
 ```
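Reviewer note: the `${env.VAR:=default}` syntax in the sample configuration substitutes the environment variable when it is set and falls back to the default otherwise. A rough sketch of that substitution rule; the regex and helper are illustrative, not Llama Stack's actual resolver:

```python
# Illustrative sketch of ${env.NAME:=default} substitution semantics.
import os
import re


def substitute(value: str) -> str:
    """Replace ${env.NAME:=default} with os.environ.get(NAME, default)."""
    return re.sub(
        r"\$\{env\.([A-Za-z0-9_]+):=([^}]*)\}",
        lambda m: os.environ.get(m.group(1), m.group(2)),
        value,
    )


os.environ.pop("AWS_DEFAULT_REGION", None)
assert substitute("${env.AWS_DEFAULT_REGION:=us-east-2}") == "us-east-2"
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
assert substitute("${env.AWS_DEFAULT_REGION:=us-east-2}") == "eu-west-1"
```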
@@ -46,6 +46,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
@@ -46,6 +46,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
@@ -46,6 +46,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
@@ -138,10 +138,11 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="bedrock",
             provider_type="remote::bedrock",
-            pip_packages=["boto3"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.bedrock",
             config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
-            description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
+            provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
+            description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
         ),
         RemoteProviderSpec(
             api=Api.inference,
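Reviewer note: the new `provider_data_validator` is what allows callers to supply a Bedrock API key per request instead of configuring one on the server. A hedged sketch of such a request; the header name and the `aws_bedrock_api_key` field come from this diff, while the server URL and route are assumptions for illustration:

```python
# Hypothetical request against a locally running Llama Stack server.
import json

import requests

resp = requests.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",  # route is an assumption
    headers={"x-llamastack-provider-data": json.dumps({"aws_bedrock_api_key": "my-key"})},
    json={
        "model": "meta.llama3-1-8b-instruct-v1:0",
        "messages": [{"role": "user", "content": "hi"}],
    },
)
print(resp.status_code, resp.json())
```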
@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):
 
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
 
-    impl = BedrockInferenceAdapter(config)
+    impl = BedrockInferenceAdapter(config=config)
 
     await impl.initialize()
 
@@ -4,139 +4,106 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 
-from botocore.client import BaseClient
+from openai import AuthenticationError
 
 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    Inference,
+    Model,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
 )
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
-from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
-from .models import MODEL_ENTRIES
+from .config import BedrockConfig
 
-REGION_PREFIX_MAP = {
-    "us": "us.",
-    "eu": "eu.",
-    "ap": "ap.",
-}
-
-
-def _get_region_prefix(region: str | None) -> str:
-    # AWS requires region prefixes for inference profiles
-    if region is None:
-        return "us."  # default to US when we don't know
-
-    # Handle case insensitive region matching
-    region_lower = region.lower()
-    for prefix in REGION_PREFIX_MAP:
-        if region_lower.startswith(f"{prefix}-"):
-            return REGION_PREFIX_MAP[prefix]
-
-    # Fallback to US for anything we don't recognize
-    return "us."
-
-
-def _to_inference_profile_id(model_id: str, region: str = None) -> str:
-    # Return ARNs unchanged
-    if model_id.startswith("arn:"):
-        return model_id
-
-    # Return inference profile IDs that already have regional prefixes
-    if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
-        return model_id
-
-    # Default to US East when no region is provided
-    if region is None:
-        region = "us-east-1"
-
-    return _get_region_prefix(region) + model_id
-
-
-class BedrockInferenceAdapter(
-    ModelRegistryHelper,
-    Inference,
-):
-    def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-        self._config = config
-        self._client = None
-
-    @property
-    def client(self) -> BaseClient:
-        if self._client is None:
-            self._client = create_bedrock_client(self._config)
-        return self._client
-
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        if self._client is not None:
-            self._client.close()
-
-    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
-        bedrock_model = request.model
-
-        sampling_params = request.sampling_params
-        options = get_sampling_strategy_options(sampling_params)
-
-        if sampling_params.max_tokens:
-            options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
-
-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-
-        # Convert foundation model ID to inference profile ID
-        region_name = self.client.meta.region_name
-        inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
-
-        return {
-            "modelId": inference_profile_id,
-            "body": json.dumps(
-                {
-                    "prompt": prompt,
-                    **options,
-                }
-            ),
-        }
+
+class BedrockInferenceAdapter(OpenAIMixin):
+    """
+    Adapter for AWS Bedrock's OpenAI-compatible API endpoints.
+
+    Supports Llama models across regions and GPT-OSS models (us-west-2 only).
+
+    Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
+    for dynamic model discovery. Models must be pre-registered in the config.
+    """
+
+    config: BedrockConfig
+    provider_data_api_key_field: str = "aws_bedrock_api_key"
+
+    def get_api_key(self) -> str:
+        """Get API key for OpenAI client."""
+        if not self.config.api_key:
+            raise ValueError(
+                "API key is not set. Please provide a valid API key in the "
+                "provider config or via AWS_BEDROCK_API_KEY environment variable."
+            )
+        return self.config.api_key
+
+    def get_base_url(self) -> str:
+        """Get base URL for OpenAI client."""
+        return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
+        Returns empty list since models must be pre-registered in the config.
+        """
+        return []
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Register a model with the Bedrock provider.
+
+        Bedrock doesn't support dynamic model listing via /v1/models, so we skip
+        the availability check and accept all models registered in the config.
+        """
+        return model
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+        """Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
+        """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
+            "Only /v1/chat/completions is supported. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
+        """Override to enable streaming usage metrics and handle authentication errors."""
+        # Enable streaming usage metrics when telemetry is active
+        if params.stream and get_current_span() is not None:
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
+        # Wrap call in try/except to catch authentication errors
+        try:
+            return await super().openai_chat_completion(params=params)
+        except AuthenticationError as e:
+            raise ValueError(
+                f"AWS Bedrock authentication failed: {e.message}. "
+                "Please check your API key in the provider config or x-llamastack-provider-data header."
+            ) from e
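Reviewer note: the `stream_options` handling in `openai_chat_completion` above is easy to misread. Restated in isolation, the rule is "when a telemetry span is active, force `include_usage` on while preserving any caller-supplied options"; the adapter applies it only when `params.stream` is true. A standalone sketch, not part of the diff:

```python
# Standalone restatement of the stream_options merge in openai_chat_completion.
def merge_stream_options(stream_options: dict | None, telemetry_active: bool) -> dict | None:
    """Force include_usage on while preserving caller-supplied options."""
    if not telemetry_active:
        return stream_options
    if stream_options is None:
        return {"include_usage": True}
    if "include_usage" not in stream_options:
        return {**stream_options, "include_usage": True}
    return stream_options


assert merge_stream_options(None, telemetry_active=True) == {"include_usage": True}
assert merge_stream_options({"other": 1}, telemetry_active=True) == {"other": 1, "include_usage": True}
assert merge_stream_options(None, telemetry_active=False) is None
```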
@@ -4,8 +4,33 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
+import os
+
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 
 
-class BedrockConfig(BedrockBaseConfig):
-    pass
+class BedrockProviderDataValidator(BaseModel):
+    aws_bedrock_api_key: str | None = Field(
+        default=None,
+        description="API key for Amazon Bedrock",
+    )
+
+
+class BedrockConfig(RemoteInferenceProviderConfig):
+    api_key: str | None = Field(
+        default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
+        description="Amazon Bedrock API key",
+    )
+    region_name: str = Field(
+        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
+        description="AWS Region for the Bedrock Runtime endpoint",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs):
+        return {
+            "api_key": "${env.AWS_BEDROCK_API_KEY:=}",
+            "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
+        }
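Reviewer note: because both fields use `default_factory`, the environment is consulted each time `BedrockConfig` is instantiated rather than once at import time; that is what makes the monkeypatch-based tests below deterministic. A quick sketch of the behavior this implies:

```python
# Sketch: default_factory runs on every instantiation, so the environment is
# re-read per BedrockConfig() call rather than once at import time.
import os

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
assert BedrockConfig().region_name == "eu-west-1"

os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
assert BedrockConfig().region_name == "us-west-2"  # picked up by the new instance
```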
@@ -1,29 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_entry,
-)
-
-SAFETY_MODELS_ENTRIES = []
-
-
-# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta.llama3-1-8b-instruct-v1:0",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta.llama3-1-70b-instruct-v1:0",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta.llama3-1-405b-instruct-v1:0",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-] + SAFETY_MODELS_ENTRIES
tests/unit/providers/inference/test_bedrock_adapter.py (new file, +81)
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from openai import AuthenticationError
+
+from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
+from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
+
+
+def test_adapter_initialization():
+    config = BedrockConfig(api_key="test-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter.config.api_key == "test-key"
+    assert adapter.config.region_name == "us-east-1"
+
+
+def test_client_url_construction():
+    config = BedrockConfig(api_key="test-key", region_name="us-west-2")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"
+    assert adapter.get_api_key() == "test-key"
+
+
+def test_api_key_from_config():
+    """Test API key is read from config"""
+    config = BedrockConfig(api_key="config-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter.get_api_key() == "config-key"
+
+
+def test_api_key_from_header_overrides_config():
+    """Test API key from request header overrides config via client property"""
+    config = BedrockConfig(api_key="config-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+    adapter.provider_data_api_key_field = "aws_bedrock_api_key"
+    adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bedrock_api_key="header-key"))
+
+    # The client property is where header override happens (in OpenAIMixin)
+    assert adapter.client.api_key == "header-key"
+
+
+async def test_authentication_error_handling():
+    """Test that AuthenticationError from OpenAI client is converted to ValueError with helpful message"""
+    config = BedrockConfig(api_key="invalid-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    # Mock the parent class method to raise AuthenticationError
+    mock_response = MagicMock()
+    mock_response.message = "Invalid authentication credentials"
+    auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None)
+
+    # Create a mock that raises the error
+    mock_super = AsyncMock(side_effect=auth_error)
+
+    # Patch the parent class method
+    original_method = BedrockInferenceAdapter.__bases__[0].openai_chat_completion
+    BedrockInferenceAdapter.__bases__[0].openai_chat_completion = mock_super
+
+    try:
+        with pytest.raises(ValueError) as exc_info:
+            params = OpenAIChatCompletionRequestWithExtraBody(
+                model="test-model", messages=[{"role": "user", "content": "test"}]
+            )
+            await adapter.openai_chat_completion(params=params)
+
+        assert "AWS Bedrock authentication failed" in str(exc_info.value)
+        assert "Please check your API key" in str(exc_info.value)
+    finally:
+        # Restore original method
+        BedrockInferenceAdapter.__bases__[0].openai_chat_completion = original_method
tests/unit/providers/inference/test_bedrock_config.py (new file, +41)
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
+
+
+def test_bedrock_config_defaults_no_env(monkeypatch):
+    """Test BedrockConfig defaults when env vars are not set"""
+    monkeypatch.delenv("AWS_BEDROCK_API_KEY", raising=False)
+    monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
+    config = BedrockConfig()
+    assert config.api_key is None
+    assert config.region_name == "us-east-2"
+
+
+def test_bedrock_config_defaults_with_env(monkeypatch):
+    """Test BedrockConfig reads from environment variables"""
+    monkeypatch.setenv("AWS_BEDROCK_API_KEY", "env-key")
+    monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
+    config = BedrockConfig()
+    assert config.api_key == "env-key"
+    assert config.region_name == "eu-west-1"
+
+
+def test_bedrock_config_with_values():
+    """Test BedrockConfig accepts explicit values"""
+    config = BedrockConfig(api_key="test-key", region_name="us-west-2")
+    assert config.api_key == "test-key"
+    assert config.region_name == "us-west-2"
+
+
+def test_bedrock_config_sample():
+    """Test BedrockConfig sample_run_config returns correct format"""
+    sample = BedrockConfig.sample_run_config()
+    assert "api_key" in sample
+    assert "region_name" in sample
+    assert sample["api_key"] == "${env.AWS_BEDROCK_API_KEY:=}"
+    assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}"
@@ -4,50 +4,63 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.remote.inference.bedrock.bedrock import (
-    _get_region_prefix,
-    _to_inference_profile_id,
-)
-
-
-def test_region_prefixes():
-    assert _get_region_prefix("us-east-1") == "us."
-    assert _get_region_prefix("eu-west-1") == "eu."
-    assert _get_region_prefix("ap-south-1") == "ap."
-    assert _get_region_prefix("ca-central-1") == "us."
-
-    # Test case insensitive
-    assert _get_region_prefix("US-EAST-1") == "us."
-    assert _get_region_prefix("EU-WEST-1") == "eu."
-    assert _get_region_prefix("Ap-South-1") == "ap."
-
-    # Test None region
-    assert _get_region_prefix(None) == "us."
-
-
-def test_model_id_conversion():
-    # Basic conversion
-    assert (
-        _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
-    )
-
-    # Already has prefix
-    assert (
-        _to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
-        == "us.meta.llama3-1-70b-instruct-v1:0"
-    )
-
-    # ARN should be returned unchanged
-    arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
-    assert _to_inference_profile_id(arn, "us-east-1") == arn
-
-    # ARN should be returned unchanged even without region
-    assert _to_inference_profile_id(arn) == arn
-
-    # Optional region parameter defaults to us-east-1
-    assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
-
-    # Different regions work with optional parameter
-    assert (
-        _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
-    )
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, PropertyMock, patch
+
+from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
+from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
+
+
+def test_can_create_adapter():
+    config = BedrockConfig(api_key="test-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter is not None
+    assert adapter.config.region_name == "us-east-1"
+    assert adapter.get_api_key() == "test-key"
+
+
+def test_different_aws_regions():
+    # just check a couple regions to verify URL construction works
+    config = BedrockConfig(api_key="key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+    assert adapter.get_base_url() == "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1"
+
+    config = BedrockConfig(api_key="key", region_name="eu-west-1")
+    adapter = BedrockInferenceAdapter(config=config)
+    assert adapter.get_base_url() == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1"
+
+
+async def test_basic_chat_completion():
+    """Test basic chat completion works with OpenAIMixin"""
+    config = BedrockConfig(api_key="k", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    class FakeModelStore:
+        async def get_model(self, model_id):
+            return SimpleNamespace(provider_resource_id="meta.llama3-1-8b-instruct-v1:0")
+
+    adapter.model_store = FakeModelStore()
+
+    fake_response = SimpleNamespace(
+        id="chatcmpl-123",
+        choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", role="assistant"), finish_reason="stop")],
+    )
+
+    mock_create = AsyncMock(return_value=fake_response)
+
+    class FakeClient:
+        def __init__(self):
+            self.chat = SimpleNamespace(completions=SimpleNamespace(create=mock_create))
+
+    with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=FakeClient()):
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="llama3-1-8b",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=False,
+        )
+        response = await adapter.openai_chat_completion(params=params)
+
+    assert response.id == "chatcmpl-123"
+    assert mock_create.await_count == 1
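Reviewer note: `test_authentication_error_handling` patches `BedrockInferenceAdapter.__bases__[0]` by hand, which works but is brittle. An equivalent sketch using `unittest.mock.patch.object` on `OpenAIMixin` directly, so the original method is restored automatically (illustrative alternative, not part of the diff):

```python
# Hypothetical variant of the auth-error test using patch.object on the mixin.
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai import AuthenticationError

from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


async def test_auth_error_with_patch_object():
    adapter = BedrockInferenceAdapter(config=BedrockConfig(api_key="bad", region_name="us-east-1"))
    err = AuthenticationError(message="Invalid credentials", response=MagicMock(), body=None)

    # Replace the mixin's method for the duration of the with-block only.
    with patch.object(OpenAIMixin, "openai_chat_completion", AsyncMock(side_effect=err)):
        params = OpenAIChatCompletionRequestWithExtraBody(
            model="m", messages=[{"role": "user", "content": "hi"}]
        )
        with pytest.raises(ValueError, match="AWS Bedrock authentication failed"):
            await adapter.openai_chat_completion(params=params)
```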