diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py index bf9a070f26..86b47675d4 100644 --- a/litellm/llms/bedrock/base_aws_llm.py +++ b/litellm/llms/bedrock/base_aws_llm.py @@ -554,6 +554,7 @@ class BaseAWSLLM: aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = self._get_aws_region_name(optional_params, model) + optional_params.pop("aws_region_name", None) aws_role_name = optional_params.pop("aws_role_name", None) aws_session_name = optional_params.pop("aws_session_name", None) aws_profile_name = optional_params.pop("aws_profile_name", None) diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py index 4bd63fd21b..59a80b2222 100644 --- a/litellm/llms/bedrock/image/image_handler.py +++ b/litellm/llms/bedrock/image/image_handler.py @@ -10,6 +10,8 @@ import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, _get_httpx_client, get_async_httpx_client, ) @@ -51,6 +53,7 @@ class BedrockImageGeneration(BaseAWSLLM): aimg_generation: bool = False, api_base: Optional[str] = None, extra_headers: Optional[dict] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ): prepared_request = self._prepare_request( model=model, @@ -69,9 +72,15 @@ class BedrockImageGeneration(BaseAWSLLM): logging_obj=logging_obj, prompt=prompt, model_response=model_response, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), ) - client = _get_httpx_client() + if client is None or not isinstance(client, HTTPHandler): + client = _get_httpx_client() try: response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore response.raise_for_status() @@ -99,13 +108,14 @@ class BedrockImageGeneration(BaseAWSLLM): logging_obj: LitellmLogging, prompt: str, model_response: ImageResponse, + client: Optional[AsyncHTTPHandler] = None, ) -> ImageResponse: """ Asynchronous handler for bedrock image generation Awaits the response from the bedrock image generation endpoint """ - async_client = get_async_httpx_client( + async_client = client or get_async_httpx_client( llm_provider=litellm.LlmProviders.BEDROCK, params={"timeout": timeout}, ) diff --git a/litellm/main.py b/litellm/main.py index 57bcda61fd..28dbf45102 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -4521,6 +4521,7 @@ def image_generation( # noqa: PLR0915 non_default_params = { k: v for k, v in kwargs.items() if k not in default_params } # model-specific params - pass them straight to the model/provider + optional_params = get_optional_params_image_gen( model=model, n=n, @@ -4532,6 +4533,7 @@ def image_generation( # noqa: PLR0915 custom_llm_provider=custom_llm_provider, **non_default_params, ) + logging: Logging = litellm_logging_obj logging.update_environment_variables( model=model, @@ -4630,6 +4632,7 @@ def image_generation( # noqa: PLR0915 optional_params=optional_params, model_response=model_response, aimg_generation=aimg_generation, + client=client, ) elif custom_llm_provider == "vertex_ai": vertex_ai_project = ( diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 7a1d15fe01..0b34c575a1 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -11,6 +11,7 @@ import uuid from typing import TYPE_CHECKING, List, Optional, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse import litellm from litellm._logging import verbose_proxy_logger diff --git a/tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py b/tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py index 10845a895d..fa5fbeda7f 100644 --- a/tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py +++ b/tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py @@ -263,3 +263,25 @@ def test_cost_calculator_basic(): assert isinstance(cost, float) assert cost > 0 + + +def test_bedrock_image_gen_with_aws_region_name(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm import image_generation + + client = HTTPHandler() + + with patch.object(client, "post") as mock_post: + try: + image_generation( + model="bedrock/stability.stable-image-ultra-v1:1", + prompt="A beautiful sunset", + aws_region_name="us-west-2", + client=client, + ) + except Exception as e: + print(e) + raise e + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + print(kwargs)