mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
feat: add OpenAI-compatible Bedrock provider with error handling
Implements AWS Bedrock inference provider using OpenAI-compatible endpoint for Llama models available through Bedrock. Changes: - Add BedrockInferenceAdapter using OpenAIMixin base - Configure region-specific endpoint URLs - Add NotImplementedError stubs for unsupported endpoints - Implement authentication error handling with helpful messages - Remove unused models.py file - Add comprehensive unit tests (12 total) - Add provider registry configuration
This commit is contained in:
parent
c899b50723
commit
4ff367251f
12 changed files with 288 additions and 187 deletions
81
tests/unit/providers/inference/test_bedrock_adapter.py
Normal file
81
tests/unit/providers/inference/test_bedrock_adapter.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from openai import AuthenticationError
|
||||
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
|
||||
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
def test_adapter_initialization():
    """The adapter keeps a reference to the config it was constructed with."""
    cfg = BedrockConfig(api_key="test-key", region_name="us-east-1")
    bedrock = BedrockInferenceAdapter(config=cfg)

    assert bedrock.config.api_key == "test-key"
    assert bedrock.config.region_name == "us-east-1"
def test_client_url_construction():
    """Base URL embeds the configured region; API key is read from config."""
    cfg = BedrockConfig(api_key="test-key", region_name="us-west-2")
    bedrock = BedrockInferenceAdapter(config=cfg)

    expected_url = "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"
    assert bedrock.get_base_url() == expected_url
    assert bedrock.get_api_key() == "test-key"
def test_api_key_from_config():
    """get_api_key() returns the key supplied in the provider config."""
    cfg = BedrockConfig(api_key="config-key", region_name="us-east-1")
    bedrock = BedrockInferenceAdapter(config=cfg)

    assert bedrock.get_api_key() == "config-key"
def test_api_key_from_header_overrides_config():
    """An API key carried in request provider data wins over the config key."""
    cfg = BedrockConfig(api_key="config-key", region_name="us-east-1")
    bedrock = BedrockInferenceAdapter(config=cfg)

    # Simulate a per-request header carrying its own Bedrock API key.
    bedrock.provider_data_api_key_field = "aws_bedrock_api_key"
    provider_data = SimpleNamespace(aws_bedrock_api_key="header-key")
    bedrock.get_request_provider_data = MagicMock(return_value=provider_data)

    # The header override is applied inside OpenAIMixin's client property.
    assert bedrock.client.api_key == "header-key"
async def test_authentication_error_handling():
    """AuthenticationError from the OpenAI client is converted to ValueError
    with a helpful AWS Bedrock message.

    The parent-class (OpenAIMixin) ``openai_chat_completion`` is replaced with
    an AsyncMock that raises AuthenticationError, so the adapter's error
    translation is exercised in isolation.
    """
    config = BedrockConfig(api_key="invalid-key", region_name="us-east-1")
    adapter = BedrockInferenceAdapter(config=config)

    # Build the request OUTSIDE the pytest.raises block: pydantic's
    # ValidationError subclasses ValueError, so constructing the model inside
    # the block could satisfy the assertion without the adapter ever running.
    params = OpenAIChatCompletionRequestWithExtraBody(
        model="test-model", messages=[{"role": "user", "content": "test"}]
    )

    # Minimal stand-in for the HTTP response AuthenticationError expects.
    mock_response = MagicMock()
    mock_response.message = "Invalid authentication credentials"
    auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None)

    # Patch the parent-class method so the adapter's override sees the error.
    parent = BedrockInferenceAdapter.__bases__[0]
    original_method = parent.openai_chat_completion
    parent.openai_chat_completion = AsyncMock(side_effect=auth_error)

    try:
        # Keep the raises body minimal: only the call under test.
        with pytest.raises(ValueError) as exc_info:
            await adapter.openai_chat_completion(params=params)

        message = str(exc_info.value)
        assert "AWS Bedrock authentication failed" in message
        assert "Please check your API key" in message
    finally:
        # Always restore so other tests see the real implementation.
        parent.openai_chat_completion = original_method
|
||||
Loading…
Add table
Add a link
Reference in a new issue