llama-stack-mirror/llama_stack/providers/remote/inference/bedrock/bedrock.py
Eric Huang a93130e323 test
# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
Completes the refactoring started in previous commit by:

1. **Fix library client** (critical): Add logic to detect Pydantic model parameters
   and construct them properly from request bodies. The key fix is to NOT exclude
   any params when converting the body for Pydantic models - we need all fields
   to pass to the Pydantic constructor.

   Before: _convert_body excluded all params, leaving body empty for Pydantic construction
   After: Check for Pydantic params first, skip exclusion, construct model with full body

2. **Update remaining providers** to use new Pydantic-based signatures:
   - litellm_openai_mixin: Extract extra fields via __pydantic_extra__
   - databricks: Use TYPE_CHECKING import for params type
   - llama_openai_compat: Use TYPE_CHECKING import for params type
   - sentence_transformers: Update method signatures to use params

3. **Update unit tests** to use new Pydantic signature:
   - test_openai_mixin.py: Use OpenAIChatCompletionRequestParams

This fixes test failures where the library client was trying to construct
Pydantic models with empty dictionaries.
The previous fix had a bug: it called _convert_body() which only keeps fields
that match function parameter names. For Pydantic methods with signature:
  openai_chat_completion(params: OpenAIChatCompletionRequestParams)

The signature only has 'params', but the body has 'model', 'messages', etc.
So _convert_body() returned an empty dict.

Fix: Skip _convert_body() entirely for Pydantic params. Use the raw body
directly to construct the Pydantic model (after stripping NOT_GIVENs).

This properly fixes the ValidationError where required fields were missing.
The streaming code path (_call_streaming) had the same issue as non-streaming:
it called _convert_body() which returned empty dict for Pydantic params.

Applied the same fix as commit 7476c0ae:
- Detect Pydantic model parameters before body conversion
- Skip _convert_body() for Pydantic params
- Construct Pydantic model directly from raw body (after stripping NOT_GIVENs)

This fixes streaming endpoints like openai_chat_completion with stream=True.
The streaming code path (_call_streaming) had the same issue as non-streaming:
it called _convert_body() which returned empty dict for Pydantic params.

Applied the same fix as commit 7476c0ae:
- Detect Pydantic model parameters before body conversion
- Skip _convert_body() for Pydantic params
- Construct Pydantic model directly from raw body (after stripping NOT_GIVENs)

This fixes streaming endpoints like openai_chat_completion with stream=True.
2025-10-09 13:53:33 -07:00

148 lines
4.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIChatCompletionRequestParams,
OpenAICompletionRequestParams,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .models import MODEL_ENTRIES
REGION_PREFIX_MAP = {
"us": "us.",
"eu": "eu.",
"ap": "ap.",
}
def _get_region_prefix(region: str | None) -> str:
# AWS requires region prefixes for inference profiles
if region is None:
return "us." # default to US when we don't know
# Handle case insensitive region matching
region_lower = region.lower()
for prefix in REGION_PREFIX_MAP:
if region_lower.startswith(f"{prefix}-"):
return REGION_PREFIX_MAP[prefix]
# Fallback to US for anything we don't recognize
return "us."
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
# Return ARNs unchanged
if model_id.startswith("arn:"):
return model_id
# Return inference profile IDs that already have regional prefixes
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
return model_id
# Default to US East when no region is provided
if region is None:
region = "us-east-1"
return _get_region_prefix(region) + model_id
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self._config = config
self._client = None
@property
def client(self) -> BaseClient:
if self._client is None:
self._client = create_bedrock_client(self._config)
return self._client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self._client is not None:
self._client.close()
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
sampling_params = request.sampling_params
options = get_sampling_strategy_options(sampling_params)
if sampling_params.max_tokens:
options["max_gen_len"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
# Convert foundation model ID to inference profile ID
region_name = self.client.meta.region_name
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
return {
"modelId": inference_profile_id,
"body": json.dumps(
{
"prompt": prompt,
**options,
}
),
}
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
params: OpenAICompletionRequestParams,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestParams,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")