fix(base_aws_llm.py): remove region name before sending in args (#8998)

* fix(base_aws_llm.py): remove region name before sending in args

* fix(base_aws_llm.py): fix optional param pop position

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-04 23:05:28 -08:00 committed by GitHub
parent 4cd35205ae
commit ecef915ac9
5 changed files with 39 additions and 2 deletions

View file

@ -554,6 +554,7 @@ class BaseAWSLLM:
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None) aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = self._get_aws_region_name(optional_params, model) aws_region_name = self._get_aws_region_name(optional_params, model)
optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None) aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None) aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None) aws_profile_name = optional_params.pop("aws_profile_name", None)

View file

@ -10,6 +10,8 @@ import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client, _get_httpx_client,
get_async_httpx_client, get_async_httpx_client,
) )
@ -51,6 +53,7 @@ class BedrockImageGeneration(BaseAWSLLM):
aimg_generation: bool = False, aimg_generation: bool = False,
api_base: Optional[str] = None, api_base: Optional[str] = None,
extra_headers: Optional[dict] = None, extra_headers: Optional[dict] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
): ):
prepared_request = self._prepare_request( prepared_request = self._prepare_request(
model=model, model=model,
@ -69,9 +72,15 @@ class BedrockImageGeneration(BaseAWSLLM):
logging_obj=logging_obj, logging_obj=logging_obj,
prompt=prompt, prompt=prompt,
model_response=model_response, model_response=model_response,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
) )
client = _get_httpx_client() if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
try: try:
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status() response.raise_for_status()
@ -99,13 +108,14 @@ class BedrockImageGeneration(BaseAWSLLM):
logging_obj: LitellmLogging, logging_obj: LitellmLogging,
prompt: str, prompt: str,
model_response: ImageResponse, model_response: ImageResponse,
client: Optional[AsyncHTTPHandler] = None,
) -> ImageResponse: ) -> ImageResponse:
""" """
Asynchronous handler for bedrock image generation Asynchronous handler for bedrock image generation
Awaits the response from the bedrock image generation endpoint Awaits the response from the bedrock image generation endpoint
""" """
async_client = get_async_httpx_client( async_client = client or get_async_httpx_client(
llm_provider=litellm.LlmProviders.BEDROCK, llm_provider=litellm.LlmProviders.BEDROCK,
params={"timeout": timeout}, params={"timeout": timeout},
) )

View file

@ -4521,6 +4521,7 @@ def image_generation( # noqa: PLR0915
non_default_params = { non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider } # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen( optional_params = get_optional_params_image_gen(
model=model, model=model,
n=n, n=n,
@ -4532,6 +4533,7 @@ def image_generation( # noqa: PLR0915
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
**non_default_params, **non_default_params,
) )
logging: Logging = litellm_logging_obj logging: Logging = litellm_logging_obj
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -4630,6 +4632,7 @@ def image_generation( # noqa: PLR0915
optional_params=optional_params, optional_params=optional_params,
model_response=model_response, model_response=model_response,
aimg_generation=aimg_generation, aimg_generation=aimg_generation,
client=client,
) )
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
vertex_ai_project = ( vertex_ai_project = (

View file

@ -11,6 +11,7 @@ import uuid
from typing import TYPE_CHECKING, List, Optional, Union, cast from typing import TYPE_CHECKING, List, Optional, Union, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger

View file

@ -263,3 +263,25 @@ def test_cost_calculator_basic():
assert isinstance(cost, float) assert isinstance(cost, float)
assert cost > 0 assert cost > 0
def test_bedrock_image_gen_with_aws_region_name():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm import image_generation
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
image_generation(
model="bedrock/stability.stable-image-ultra-v1:1",
prompt="A beautiful sunset",
aws_region_name="us-west-2",
client=client,
)
except Exception as e:
print(e)
raise e
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
print(kwargs)