Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)
feat: add OpenAI-compatible Bedrock provider with error handling

Implements an AWS Bedrock inference provider that uses Bedrock's OpenAI-compatible endpoint for the Llama models available through Bedrock.

Changes:
- Add BedrockInferenceAdapter using the OpenAIMixin base
- Configure region-specific endpoint URLs
- Add NotImplementedError stubs for unsupported endpoints
- Implement authentication error handling with helpful messages
- Remove unused models.py file
- Add comprehensive unit tests (12 total)
- Add provider registry configuration
This commit is contained in: parent c899b50723, commit 4ff367251f.

12 changed files with 288 additions and 187 deletions.
Docs page for `remote::bedrock`:

````diff
@@ -1,5 +1,5 @@
 ---
-description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
+description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
 sidebar_label: Remote - Bedrock
 title: remote::bedrock
 ---
@@ -8,7 +8,7 @@ title: remote::bedrock
 
 ## Description
 
-AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
+AWS Bedrock inference provider using OpenAI compatible endpoint.
 
 ## Configuration
 
@@ -16,19 +16,12 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
-| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
-| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
-| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2. Default use environment variable: AWS_DEFAULT_REGION |
-| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use. Default use environment variable: AWS_PROFILE |
-| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
-| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform. Default use environment variable: AWS_RETRY_MODE |
-| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
-| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection. The default is 60 seconds. |
-| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
+| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |
 
 ## Sample Configuration
 
 ```yaml
-{}
+api_key: ${env.AWS_BEDROCK_API_KEY:=}
+region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
 ```
````
The provider entry is added to three distribution run configs; the identical hunk appears three times in the diff, once per file:

```diff
@@ -46,6 +46,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
+  - provider_id: bedrock
+    provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
```
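For reference, llama-stack run configs resolve `${env.VAR:=default}` to the variable's value, falling back to the default (possibly empty), while `${env.VAR:+value}` yields `value` only when the variable is set — which is how the nvidia provider above is enabled conditionally. A minimal sketch of what the new entry resolves to when neither variable is set (assuming the empty default yields an empty key):

```yaml
# Hedged sketch: resolved Bedrock provider entry with AWS_BEDROCK_API_KEY and
# AWS_DEFAULT_REGION unset; api_key falls back to empty, region_name to us-east-2.
- provider_id: bedrock
  provider_type: remote::bedrock
  config:
    api_key: ""
    region_name: us-east-2
```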
Provider registry entry:

```diff
@@ -138,10 +138,11 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="bedrock",
             provider_type="remote::bedrock",
-            pip_packages=["boto3"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.bedrock",
             config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
-            description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
+            provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
+            description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
         ),
         RemoteProviderSpec(
             api=Api.inference,
```
The adapter factory:

```diff
@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):
 
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
 
-    impl = BedrockInferenceAdapter(config)
+    impl = BedrockInferenceAdapter(config=config)
 
     await impl.initialize()
 
```
The adapter itself (`bedrock.py`), rewritten on top of `OpenAIMixin`:

```diff
@@ -4,139 +4,106 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 
-from botocore.client import BaseClient
+from openai import AuthenticationError
 
 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    Inference,
+    Model,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
 )
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
-from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-)
-
-from .models import MODEL_ENTRIES
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
-
-REGION_PREFIX_MAP = {
-    "us": "us.",
-    "eu": "eu.",
-    "ap": "ap.",
-}
-
-
-def _get_region_prefix(region: str | None) -> str:
-    # AWS requires region prefixes for inference profiles
-    if region is None:
-        return "us."  # default to US when we don't know
-
-    # Handle case insensitive region matching
-    region_lower = region.lower()
-    for prefix in REGION_PREFIX_MAP:
-        if region_lower.startswith(f"{prefix}-"):
-            return REGION_PREFIX_MAP[prefix]
-
-    # Fallback to US for anything we don't recognize
-    return "us."
-
-
-def _to_inference_profile_id(model_id: str, region: str = None) -> str:
-    # Return ARNs unchanged
-    if model_id.startswith("arn:"):
-        return model_id
-
-    # Return inference profile IDs that already have regional prefixes
-    if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
-        return model_id
-
-    # Default to US East when no region is provided
-    if region is None:
-        region = "us-east-1"
-
-    return _get_region_prefix(region) + model_id
-
-
-class BedrockInferenceAdapter(
-    ModelRegistryHelper,
-    Inference,
-):
-    def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-        self._config = config
-        self._client = None
-
-    @property
-    def client(self) -> BaseClient:
-        if self._client is None:
-            self._client = create_bedrock_client(self._config)
-        return self._client
-
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        if self._client is not None:
-            self._client.close()
-
-    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
-        bedrock_model = request.model
-
-        sampling_params = request.sampling_params
-        options = get_sampling_strategy_options(sampling_params)
-
-        if sampling_params.max_tokens:
-            options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
-
-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-
-        # Convert foundation model ID to inference profile ID
-        region_name = self.client.meta.region_name
-        inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
-
-        return {
-            "modelId": inference_profile_id,
-            "body": json.dumps(
-                {
-                    "prompt": prompt,
-                    **options,
-                }
-            ),
-        }
+
+from .config import BedrockConfig
+
+
+class BedrockInferenceAdapter(OpenAIMixin):
+    """
+    Adapter for AWS Bedrock's OpenAI-compatible API endpoints.
+
+    Supports Llama models across regions and GPT-OSS models (us-west-2 only).
+
+    Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
+    for dynamic model discovery. Models must be pre-registered in the config.
+    """
+
+    config: BedrockConfig
+    provider_data_api_key_field: str = "aws_bedrock_api_key"
+
+    def get_api_key(self) -> str:
+        """Get API key for OpenAI client."""
+        if not self.config.api_key:
+            raise ValueError(
+                "API key is not set. Please provide a valid API key in the "
+                "provider config or via AWS_BEDROCK_API_KEY environment variable."
+            )
+        return self.config.api_key
+
+    def get_base_url(self) -> str:
+        """Get base URL for OpenAI client."""
+        return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
+        Returns empty list since models must be pre-registered in the config.
+        """
+        return []
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Register a model with the Bedrock provider.
+
+        Bedrock doesn't support dynamic model listing via /v1/models, so we skip
+        the availability check and accept all models registered in the config.
+        """
+        return model
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+        """Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
+        """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
+            "Only /v1/chat/completions is supported. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
+        """Override to enable streaming usage metrics and handle authentication errors."""
+        # Enable streaming usage metrics when telemetry is active
+        if params.stream and get_current_span() is not None:
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
+        # Wrap call in try/except to catch authentication errors
+        try:
+            return await super().openai_chat_completion(params=params)
+        except AuthenticationError as e:
+            raise ValueError(
+                f"AWS Bedrock authentication failed: {e.message}. "
+                "Please check your API key in the provider config or x-llamastack-provider-data header."
+            ) from e
```
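Because the adapter simply points the standard OpenAI client at Bedrock's endpoint, the wiring can be sanity-checked outside the stack. A minimal sketch (assumptions: a valid Bedrock API key in `AWS_BEDROCK_API_KEY`, and a Llama model such as `meta.llama3-1-8b-instruct-v1:0` enabled in the chosen region):

```python
# Hedged sketch, not part of the commit: call Bedrock's OpenAI-compatible
# endpoint directly with the stock openai client, mirroring get_base_url().
import os

from openai import OpenAI

region = os.environ.get("AWS_DEFAULT_REGION", "us-east-2")
client = OpenAI(
    base_url=f"https://bedrock-runtime.{region}.amazonaws.com/openai/v1",
    api_key=os.environ["AWS_BEDROCK_API_KEY"],
)

# The model ID must be enabled in your account/region; this one is taken
# from the (now removed) models.py in this commit.
resp = client.chat.completions.create(
    model="meta.llama3-1-8b-instruct-v1:0",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.choices[0].message.content)
```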
The provider config (`config.py`):

```diff
@@ -4,8 +4,33 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
+import os
+
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 
 
-class BedrockConfig(BedrockBaseConfig):
-    pass
+class BedrockProviderDataValidator(BaseModel):
+    aws_bedrock_api_key: str | None = Field(
+        default=None,
+        description="API key for Amazon Bedrock",
+    )
+
+
+class BedrockConfig(RemoteInferenceProviderConfig):
+    api_key: str | None = Field(
+        default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
+        description="Amazon Bedrock API key",
+    )
+    region_name: str = Field(
+        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
+        description="AWS Region for the Bedrock Runtime endpoint",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs):
+        return {
+            "api_key": "${env.AWS_BEDROCK_API_KEY:=}",
+            "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
+        }
```
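The validator above is what lets callers override the configured key per request. A hedged sketch of supplying `aws_bedrock_api_key` through the `x-llamastack-provider-data` header (the header name comes from the error message in `openai_chat_completion`; the stack URL is a placeholder, and `default_headers` support is assumed from the generated client):

```python
# Hedged sketch: per-request Bedrock API key via the provider-data header.
# The JSON field name matches BedrockProviderDataValidator.aws_bedrock_api_key.
import json

from llama_stack_client import LlamaStackClient  # assumes the llama-stack client package

client = LlamaStackClient(
    base_url="http://localhost:8321",  # placeholder stack endpoint
    default_headers={
        "x-llamastack-provider-data": json.dumps({"aws_bedrock_api_key": "request-scoped-key"})
    },
)
```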
The now-unused `models.py`, deleted:

```diff
@@ -1,29 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_entry,
-)
-
-SAFETY_MODELS_ENTRIES = []
-
-
-# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta.llama3-1-8b-instruct-v1:0",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta.llama3-1-70b-instruct-v1:0",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta.llama3-1-405b-instruct-v1:0",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-] + SAFETY_MODELS_ENTRIES
```
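With the static model list gone, nothing is auto-registered: `list_provider_model_ids()` returns an empty list and `register_model()` accepts whatever the config declares. A hedged sketch of pre-registering a Bedrock Llama model in a run config (field names follow llama-stack's run-config model conventions; the model ID is one from the deleted file):

```yaml
# Hedged sketch: declaring a Bedrock model in run.yaml so the stack can route to it.
models:
- model_id: meta.llama3-1-8b-instruct-v1:0
  provider_id: bedrock
  provider_model_id: meta.llama3-1-8b-instruct-v1:0
```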
tests/unit/providers/inference/test_bedrock_adapter.py (new file, 81 lines):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
from openai import AuthenticationError

from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig


def test_adapter_initialization():
    config = BedrockConfig(api_key="test-key", region_name="us-east-1")
    adapter = BedrockInferenceAdapter(config=config)

    assert adapter.config.api_key == "test-key"
    assert adapter.config.region_name == "us-east-1"


def test_client_url_construction():
    config = BedrockConfig(api_key="test-key", region_name="us-west-2")
    adapter = BedrockInferenceAdapter(config=config)

    assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"
    assert adapter.get_api_key() == "test-key"


def test_api_key_from_config():
    """Test API key is read from config"""
    config = BedrockConfig(api_key="config-key", region_name="us-east-1")
    adapter = BedrockInferenceAdapter(config=config)

    assert adapter.get_api_key() == "config-key"


def test_api_key_from_header_overrides_config():
    """Test API key from request header overrides config via client property"""
    config = BedrockConfig(api_key="config-key", region_name="us-east-1")
    adapter = BedrockInferenceAdapter(config=config)
    adapter.provider_data_api_key_field = "aws_bedrock_api_key"
    adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bedrock_api_key="header-key"))

    # The client property is where header override happens (in OpenAIMixin)
    assert adapter.client.api_key == "header-key"


async def test_authentication_error_handling():
    """Test that AuthenticationError from OpenAI client is converted to ValueError with helpful message"""
    config = BedrockConfig(api_key="invalid-key", region_name="us-east-1")
    adapter = BedrockInferenceAdapter(config=config)

    # Mock the parent class method to raise AuthenticationError
    mock_response = MagicMock()
    mock_response.message = "Invalid authentication credentials"
    auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None)

    # Create a mock that raises the error
    mock_super = AsyncMock(side_effect=auth_error)

    # Patch the parent class method
    original_method = BedrockInferenceAdapter.__bases__[0].openai_chat_completion
    BedrockInferenceAdapter.__bases__[0].openai_chat_completion = mock_super

    try:
        with pytest.raises(ValueError) as exc_info:
            params = OpenAIChatCompletionRequestWithExtraBody(
                model="test-model", messages=[{"role": "user", "content": "test"}]
            )
            await adapter.openai_chat_completion(params=params)

        assert "AWS Bedrock authentication failed" in str(exc_info.value)
        assert "Please check your API key" in str(exc_info.value)
    finally:
        # Restore original method
        BedrockInferenceAdapter.__bases__[0].openai_chat_completion = original_method
```
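Patching `BedrockInferenceAdapter.__bases__[0]` works but silently breaks if the class's base order ever changes. A hedged alternative sketch, assuming `OpenAIMixin` is the base that defines `openai_chat_completion`:

```python
# Hedged sketch: patch the parent implementation with mock.patch.object instead
# of mutating __bases__; `auth_error` is the error constructed in the test above.
from unittest.mock import AsyncMock, patch

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

with patch.object(OpenAIMixin, "openai_chat_completion", new=AsyncMock(side_effect=auth_error)):
    # exercise adapter.openai_chat_completion(params=params) and assert on ValueError
    ...
```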
tests/unit/providers/inference/test_bedrock_config.py (new file, 41 lines):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig


def test_bedrock_config_defaults_no_env(monkeypatch):
    """Test BedrockConfig defaults when env vars are not set"""
    monkeypatch.delenv("AWS_BEDROCK_API_KEY", raising=False)
    monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
    config = BedrockConfig()
    assert config.api_key is None
    assert config.region_name == "us-east-2"


def test_bedrock_config_defaults_with_env(monkeypatch):
    """Test BedrockConfig reads from environment variables"""
    monkeypatch.setenv("AWS_BEDROCK_API_KEY", "env-key")
    monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
    config = BedrockConfig()
    assert config.api_key == "env-key"
    assert config.region_name == "eu-west-1"


def test_bedrock_config_with_values():
    """Test BedrockConfig accepts explicit values"""
    config = BedrockConfig(api_key="test-key", region_name="us-west-2")
    assert config.api_key == "test-key"
    assert config.region_name == "us-west-2"


def test_bedrock_config_sample():
    """Test BedrockConfig sample_run_config returns correct format"""
    sample = BedrockConfig.sample_run_config()
    assert "api_key" in sample
    assert "region_name" in sample
    assert sample["api_key"] == "${env.AWS_BEDROCK_API_KEY:=}"
    assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}"
```
The existing Bedrock test module, rewritten for the new adapter (the old region-prefix helper tests are removed along with their helpers):

```diff
@@ -4,50 +4,63 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.remote.inference.bedrock.bedrock import (
-    _get_region_prefix,
-    _to_inference_profile_id,
-)
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, PropertyMock, patch
+
+from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
+from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 
 
-def test_region_prefixes():
-    assert _get_region_prefix("us-east-1") == "us."
-    assert _get_region_prefix("eu-west-1") == "eu."
-    assert _get_region_prefix("ap-south-1") == "ap."
-    assert _get_region_prefix("ca-central-1") == "us."
-
-    # Test case insensitive
-    assert _get_region_prefix("US-EAST-1") == "us."
-    assert _get_region_prefix("EU-WEST-1") == "eu."
-    assert _get_region_prefix("Ap-South-1") == "ap."
-
-    # Test None region
-    assert _get_region_prefix(None) == "us."
-
-
-def test_model_id_conversion():
-    # Basic conversion
-    assert (
-        _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
-    )
-
-    # Already has prefix
-    assert (
-        _to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
-        == "us.meta.llama3-1-70b-instruct-v1:0"
-    )
-
-    # ARN should be returned unchanged
-    arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
-    assert _to_inference_profile_id(arn, "us-east-1") == arn
-
-    # ARN should be returned unchanged even without region
-    assert _to_inference_profile_id(arn) == arn
-
-    # Optional region parameter defaults to us-east-1
-    assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
-
-    # Different regions work with optional parameter
-    assert (
-        _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
-    )
+def test_can_create_adapter():
+    config = BedrockConfig(api_key="test-key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    assert adapter is not None
+    assert adapter.config.region_name == "us-east-1"
+    assert adapter.get_api_key() == "test-key"
+
+
+def test_different_aws_regions():
+    # just check a couple regions to verify URL construction works
+    config = BedrockConfig(api_key="key", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+    assert adapter.get_base_url() == "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1"
+
+    config = BedrockConfig(api_key="key", region_name="eu-west-1")
+    adapter = BedrockInferenceAdapter(config=config)
+    assert adapter.get_base_url() == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1"
+
+
+async def test_basic_chat_completion():
+    """Test basic chat completion works with OpenAIMixin"""
+    config = BedrockConfig(api_key="k", region_name="us-east-1")
+    adapter = BedrockInferenceAdapter(config=config)
+
+    class FakeModelStore:
+        async def get_model(self, model_id):
+            return SimpleNamespace(provider_resource_id="meta.llama3-1-8b-instruct-v1:0")
+
+    adapter.model_store = FakeModelStore()
+
+    fake_response = SimpleNamespace(
+        id="chatcmpl-123",
+        choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", role="assistant"), finish_reason="stop")],
+    )
+
+    mock_create = AsyncMock(return_value=fake_response)
+
+    class FakeClient:
+        def __init__(self):
+            self.chat = SimpleNamespace(completions=SimpleNamespace(create=mock_create))
+
+    with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=FakeClient()):
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="llama3-1-8b",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=False,
+        )
+        response = await adapter.openai_chat_completion(params=params)
+
+    assert response.id == "chatcmpl-123"
+    assert mock_create.await_count == 1
```