mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
fix: refactor auth and improve error handling for Bedrock provider
Refactor to use auth_credential for consistent credential management and improve error handling with defensive checks.

Changes:
- Use auth_credential instead of api_key for better credential handling
- Simplify model availability check to accept all pre-registered models
- Guard metrics collection when usage data is missing in responses
- Add debug logging for better troubleshooting of API issues
- Update unit tests for auth_credential refactoring
parent dc27537cce
commit 454aeaaf3e
6 changed files with 51 additions and 40 deletions
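For context, the credential pattern this commit moves to can be sketched roughly as below. This is a hedged illustration, not the repository's actual base class: the class name and the `populate_by_name` setup are assumptions inferred from the diff (the docs table now lists `api_key` as `pydantic.types.SecretStr \| None`, and the tests read `config.auth_credential.get_secret_value()`).

```python
# Hypothetical minimal sketch of the auth_credential pattern implied by this
# commit: the provider config stores the credential as a Pydantic SecretStr
# under `auth_credential`, while still accepting `api_key` as an input alias.
# The class name is illustrative; the real base class lives in llama_stack.
from pydantic import BaseModel, ConfigDict, Field, SecretStr


class RemoteProviderConfigSketch(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    # Accepts `api_key=...` at construction time via the alias, but the value
    # is wrapped in SecretStr so it never leaks into repr() or logs.
    auth_credential: SecretStr | None = Field(
        default=None,
        alias="api_key",
        description="Authentication credential for the provider",
    )


config = RemoteProviderConfigSketch(api_key="test-key")
assert config.auth_credential is not None
assert config.auth_credential.get_secret_value() == "test-key"
print(config)  # auth_credential=SecretStr('**********') -- value stays masked
```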
```diff
@@ -16,7 +16,7 @@ AWS Bedrock inference provider using OpenAI compatible endpoint.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |

 ## Sample Configuration
```
```diff
@@ -190,7 +190,7 @@ class InferenceRouter(Inference):
         response = await provider.openai_completion(params)
         response.model = request_model_id
-        if self.telemetry_enabled:
+        if self.telemetry_enabled and response.usage is not None:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
```
```diff
@@ -253,7 +253,7 @@ class InferenceRouter(Inference):
         if self.store:
             asyncio.create_task(self.store.store_chat_completion(response, params.messages))

-        if self.telemetry_enabled:
+        if self.telemetry_enabled and response.usage is not None:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
```
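The two router hunks above apply the same defensive pattern: only touch `response.usage` when it is present. Below is a minimal self-contained sketch of why the guard matters, using hypothetical stand-in types rather than the real llama-stack response models.

```python
# Why the `response.usage is not None` guard matters: some providers (or
# streaming responses without include_usage) return completions whose `usage`
# field is None, and reading `usage.prompt_tokens` would then raise
# AttributeError mid-request. Types here are illustrative stand-ins.
from dataclasses import dataclass


@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int


@dataclass
class Completion:
    model: str
    usage: Usage | None  # providers may legitimately omit usage data


def record_metrics(response: Completion, telemetry_enabled: bool) -> list[str]:
    """Collect token metrics only when usage data is actually present."""
    if telemetry_enabled and response.usage is not None:
        return [
            f"prompt_tokens={response.usage.prompt_tokens}",
            f"completion_tokens={response.usage.completion_tokens}",
        ]
    return []  # no usage data: skip metrics instead of crashing


print(record_metrics(Completion("m", Usage(3, 7)), True))  # both metrics
print(record_metrics(Completion("m", None), True))         # [] -- guarded
```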
```diff
@@ -9,7 +9,6 @@ from collections.abc import AsyncIterator, Iterable
 from openai import AuthenticationError

 from llama_stack.apis.inference import (
-    Model,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
```
```diff
@@ -40,15 +39,6 @@ class BedrockInferenceAdapter(OpenAIMixin):
     config: BedrockConfig
     provider_data_api_key_field: str = "aws_bedrock_api_key"

-    def get_api_key(self) -> str:
-        """Get API key for OpenAI client."""
-        if not self.config.api_key:
-            raise ValueError(
-                "API key is not set. Please provide a valid API key in the "
-                "provider config or via AWS_BEDROCK_API_KEY environment variable."
-            )
-        return self.config.api_key
-
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"
```
```diff
@@ -60,14 +50,12 @@ class BedrockInferenceAdapter(OpenAIMixin):
         """
         return []

-    async def register_model(self, model: Model) -> Model:
+    async def check_model_availability(self, model: str) -> bool:
         """
-        Register a model with the Bedrock provider.
-
-        Bedrock doesn't support dynamic model listing via /v1/models, so we skip
-        the availability check and accept all models registered in the config.
+        Bedrock doesn't support dynamic model listing via /v1/models.
+        Always return True to accept all models registered in the config.
         """
-        return model
+        return True

     async def openai_embeddings(
         self,
```
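The rename from `register_model` to `check_model_availability` reflects a simpler contract: since Bedrock's OpenAI-compatible endpoint exposes no `/v1/models` listing, every pre-registered model is accepted. A hedged sketch of the contrast with a provider that does support listing follows; the helper names and model id are illustrative, not llama-stack APIs.

```python
# Illustrative contrast (hypothetical names): providers with a live
# /v1/models listing can verify ids against it; Bedrock cannot, so its
# adapter accepts whatever was pre-registered in the run config.
import asyncio


async def check_model_availability_via_listing(listed_model_ids: set[str], model: str) -> bool:
    # Generic OpenAI-compatible provider: trust the live model listing.
    return model in listed_model_ids


async def check_model_availability_bedrock(model: str) -> bool:
    # Bedrock: no /v1/models endpoint, so accept all pre-registered models.
    return True


assert asyncio.run(check_model_availability_via_listing({"some.model.id"}, "some.model.id"))
assert asyncio.run(check_model_availability_bedrock("any.model.id"))
```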
```diff
@@ -102,11 +90,40 @@ class BedrockInferenceAdapter(OpenAIMixin):
         elif "include_usage" not in params.stream_options:
             params.stream_options = {**params.stream_options, "include_usage": True}

         # Wrap call in try/except to catch authentication errors
         try:
-            return await super().openai_chat_completion(params=params)
+            logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
+            result = await super().openai_chat_completion(params=params)
+            logger.debug(f"Bedrock API returned: {type(result).__name__ if result is not None else 'None'}")
+
+            # Defensive check for unexpected None response
+            if result is None:
+                logger.error(f"OpenAI client returned None for model={params.model}, stream={params.stream}")
+                raise RuntimeError(
+                    f"Bedrock API returned no response for model '{params.model}'. "
+                    "This may indicate the model is not supported or a network/API issue occurred."
+                )
+
+            return result
         except AuthenticationError as e:
-            raise ValueError(
-                f"AWS Bedrock authentication failed: {e.message}. "
-                "Please check your API key in the provider config or x-llamastack-provider-data header."
-            ) from e
+            # Extract detailed error message from the exception
+            error_msg = str(e)
+
+            # Check if this is a token expiration error
+            if "expired" in error_msg.lower() or "Bearer Token has expired" in error_msg:
+                logger.error(f"AWS Bedrock authentication token expired: {error_msg}")
+                raise ValueError(
+                    "AWS Bedrock authentication failed: Bearer token has expired. "
+                    "The AWS_BEDROCK_API_KEY environment variable contains an expired pre-signed URL. "
+                    "Please refresh your token by generating a new pre-signed URL with AWS credentials. "
+                    "Refer to AWS Bedrock documentation for details on OpenAI-compatible endpoints."
+                ) from e
+            else:
+                logger.error(f"AWS Bedrock authentication failed: {error_msg}")
+                raise ValueError(
+                    f"AWS Bedrock authentication failed: {error_msg}. "
+                    "Please verify your API key is correct in the provider config or x-llamastack-provider-data header. "
+                    "The API key should be a valid AWS pre-signed URL for Bedrock's OpenAI-compatible endpoint."
+                ) from e
+        except Exception as e:
+            logger.error(f"Unexpected error calling Bedrock API: {type(e).__name__}: {e}", exc_info=True)
+            raise
```
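Distilled from the hunk above, the error-classification logic reduces to a small pure function. The sketch below only mirrors the branch structure and messages of the diff; it is not the adapter's actual code.

```python
# Hedged distillation of the new error-classification logic: expired bearer
# tokens get a targeted remediation message, everything else gets a generic
# credentials hint. Message text is abridged from the diff above.
def classify_bedrock_auth_error(error_msg: str) -> str:
    if "expired" in error_msg.lower():
        # Expired pre-signed URL: tell the caller exactly what to refresh.
        return (
            "AWS Bedrock authentication failed: Bearer token has expired. "
            "Refresh the AWS_BEDROCK_API_KEY pre-signed URL."
        )
    # Any other auth failure: point at the two places a key can come from.
    return (
        f"AWS Bedrock authentication failed: {error_msg}. "
        "Verify the API key in the provider config or "
        "x-llamastack-provider-data header."
    )


print(classify_bedrock_auth_error("Bearer Token has expired"))
print(classify_bedrock_auth_error("invalid signature"))
```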
```diff
@@ -19,10 +19,6 @@ class BedrockProviderDataValidator(BaseModel):


 class BedrockConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
-        description="Amazon Bedrock API key",
-    )
     region_name: str = Field(
         default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
         description="AWS Region for the Bedrock Runtime endpoint",
```
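The surviving `region_name` field keeps the env-default pattern. As a quick standalone illustration of how `default_factory` interacts with the environment (the class name here is hypothetical; the behavior mirrors the `Field(default_factory=...)` lines in the hunk above):

```python
# Standalone illustration of the env-default pattern kept for region_name.
import os

from pydantic import BaseModel, Field


class RegionConfigSketch(BaseModel):
    region_name: str = Field(
        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
        description="AWS Region for the Bedrock Runtime endpoint",
    )


os.environ.pop("AWS_DEFAULT_REGION", None)
assert RegionConfigSketch().region_name == "us-east-2"  # fallback default

os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
assert RegionConfigSketch().region_name == "eu-west-1"  # env value wins
```

Because `default_factory` runs at instantiation time, changing the env var between constructions changes the default, which is what the monkeypatch-based config tests at the end of this commit rely on.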
```diff
@@ -19,7 +19,7 @@ def test_adapter_initialization():
     config = BedrockConfig(api_key="test-key", region_name="us-east-1")
     adapter = BedrockInferenceAdapter(config=config)

-    assert adapter.config.api_key == "test-key"
+    assert adapter.config.auth_credential.get_secret_value() == "test-key"
     assert adapter.config.region_name == "us-east-1"
```
```diff
@@ -28,15 +28,15 @@ def test_client_url_construction():
     adapter = BedrockInferenceAdapter(config=config)

     assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"
-    assert adapter.get_api_key() == "test-key"


 def test_api_key_from_config():
-    """Test API key is read from config"""
+    """Test API key is stored as SecretStr in auth_credential"""
     config = BedrockConfig(api_key="config-key", region_name="us-east-1")
     adapter = BedrockInferenceAdapter(config=config)

-    assert adapter.get_api_key() == "config-key"
+    # API key is stored in auth_credential field (SecretStr)
+    assert adapter.config.auth_credential.get_secret_value() == "config-key"


 def test_api_key_from_header_overrides_config():
```
```diff
@@ -12,23 +12,21 @@ def test_bedrock_config_defaults_no_env(monkeypatch):
     monkeypatch.delenv("AWS_BEDROCK_API_KEY", raising=False)
     monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
     config = BedrockConfig()
-    assert config.api_key is None
+    assert config.auth_credential is None
     assert config.region_name == "us-east-2"


-def test_bedrock_config_defaults_with_env(monkeypatch):
-    """Test BedrockConfig reads from environment variables"""
-    monkeypatch.setenv("AWS_BEDROCK_API_KEY", "env-key")
+def test_bedrock_config_reads_from_env(monkeypatch):
+    """Test BedrockConfig field initialization reads from environment variables"""
     monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
     config = BedrockConfig()
-    assert config.api_key == "env-key"
     assert config.region_name == "eu-west-1"


 def test_bedrock_config_with_values():
-    """Test BedrockConfig accepts explicit values"""
+    """Test BedrockConfig accepts explicit values via alias"""
     config = BedrockConfig(api_key="test-key", region_name="us-west-2")
-    assert config.api_key == "test-key"
+    assert config.auth_credential.get_secret_value() == "test-key"
     assert config.region_name == "us-west-2"
```