fix(base_aws_llm.py): remove region name before sending in args (#8998)

* fix(base_aws_llm.py): remove region name before sending in args

* fix(base_aws_llm.py): fix optional param pop position

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-04 23:05:28 -08:00 committed by GitHub
parent 4cd35205ae
commit ecef915ac9
5 changed files with 39 additions and 2 deletions

View file

@ -554,6 +554,7 @@ class BaseAWSLLM:
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = self._get_aws_region_name(optional_params, model)
optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)

View file

@ -10,6 +10,8 @@ import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
@ -51,6 +53,7 @@ class BedrockImageGeneration(BaseAWSLLM):
aimg_generation: bool = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
prepared_request = self._prepare_request(
model=model,
@ -69,9 +72,15 @@ class BedrockImageGeneration(BaseAWSLLM):
logging_obj=logging_obj,
prompt=prompt,
model_response=model_response,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
)
client = _get_httpx_client()
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
try:
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status()
@ -99,13 +108,14 @@ class BedrockImageGeneration(BaseAWSLLM):
logging_obj: LitellmLogging,
prompt: str,
model_response: ImageResponse,
client: Optional[AsyncHTTPHandler] = None,
) -> ImageResponse:
"""
Asynchronous handler for bedrock image generation
Awaits the response from the bedrock image generation endpoint
"""
async_client = get_async_httpx_client(
async_client = client or get_async_httpx_client(
llm_provider=litellm.LlmProviders.BEDROCK,
params={"timeout": timeout},
)

View file

@ -4521,6 +4521,7 @@ def image_generation( # noqa: PLR0915
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(
model=model,
n=n,
@ -4532,6 +4533,7 @@ def image_generation( # noqa: PLR0915
custom_llm_provider=custom_llm_provider,
**non_default_params,
)
logging: Logging = litellm_logging_obj
logging.update_environment_variables(
model=model,
@ -4630,6 +4632,7 @@ def image_generation( # noqa: PLR0915
optional_params=optional_params,
model_response=model_response,
aimg_generation=aimg_generation,
client=client,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (

View file

@ -11,6 +11,7 @@ import uuid
from typing import TYPE_CHECKING, List, Optional, Union, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
import litellm
from litellm._logging import verbose_proxy_logger

View file

@ -263,3 +263,25 @@ def test_cost_calculator_basic():
assert isinstance(cost, float)
assert cost > 0
def test_bedrock_image_gen_with_aws_region_name():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm import image_generation
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
image_generation(
model="bedrock/stability.stable-image-ultra-v1:1",
prompt="A beautiful sunset",
aws_region_name="us-west-2",
client=client,
)
except Exception as e:
print(e)
raise e
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
print(kwargs)