mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(base_aws_llm.py): remove region name before sending in args (#8998)
* fix(base_aws_llm.py): remove region name before sending in args * fix(base_aws_llm.py): fix optional param pop position * fix: fix linting error
This commit is contained in:
parent
4cd35205ae
commit
ecef915ac9
5 changed files with 39 additions and 2 deletions
|
@ -554,6 +554,7 @@ class BaseAWSLLM:
|
|||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_region_name = self._get_aws_region_name(optional_params, model)
|
||||
optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
|
|
|
@ -10,6 +10,8 @@ import litellm
|
|||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
|
@ -51,6 +53,7 @@ class BedrockImageGeneration(BaseAWSLLM):
|
|||
aimg_generation: bool = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
|
@ -69,9 +72,15 @@ class BedrockImageGeneration(BaseAWSLLM):
|
|||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
model_response=model_response,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
client = _get_httpx_client()
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||
response.raise_for_status()
|
||||
|
@ -99,13 +108,14 @@ class BedrockImageGeneration(BaseAWSLLM):
|
|||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Asynchronous handler for bedrock image generation
|
||||
|
||||
Awaits the response from the bedrock image generation endpoint
|
||||
"""
|
||||
async_client = get_async_httpx_client(
|
||||
async_client = client or get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BEDROCK,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
|
|
|
@ -4521,6 +4521,7 @@ def image_generation( # noqa: PLR0915
|
|||
non_default_params = {
|
||||
k: v for k, v in kwargs.items() if k not in default_params
|
||||
} # model-specific params - pass them straight to the model/provider
|
||||
|
||||
optional_params = get_optional_params_image_gen(
|
||||
model=model,
|
||||
n=n,
|
||||
|
@ -4532,6 +4533,7 @@ def image_generation( # noqa: PLR0915
|
|||
custom_llm_provider=custom_llm_provider,
|
||||
**non_default_params,
|
||||
)
|
||||
|
||||
logging: Logging = litellm_logging_obj
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
@ -4630,6 +4632,7 @@ def image_generation( # noqa: PLR0915
|
|||
optional_params=optional_params,
|
||||
model_response=model_response,
|
||||
aimg_generation=aimg_generation,
|
||||
client=client,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = (
|
||||
|
|
|
@ -11,6 +11,7 @@ import uuid
|
|||
from typing import TYPE_CHECKING, List, Optional, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
|
|
@ -263,3 +263,25 @@ def test_cost_calculator_basic():
|
|||
|
||||
assert isinstance(cost, float)
|
||||
assert cost > 0
|
||||
|
||||
|
||||
def test_bedrock_image_gen_with_aws_region_name():
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm import image_generation
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
image_generation(
|
||||
model="bedrock/stability.stable-image-ultra-v1:1",
|
||||
prompt="A beautiful sunset",
|
||||
aws_region_name="us-west-2",
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
mock_post.assert_called_once()
|
||||
args, kwargs = mock_post.call_args
|
||||
print(kwargs)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue