mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(base_aws_llm.py): remove region name before sending in args (#8998)
* fix(base_aws_llm.py): remove region name before sending in args * fix(base_aws_llm.py): fix optional param pop position * fix: fix linting error
This commit is contained in:
parent
4cd35205ae
commit
ecef915ac9
5 changed files with 39 additions and 2 deletions
|
@ -554,6 +554,7 @@ class BaseAWSLLM:
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||||
aws_region_name = self._get_aws_region_name(optional_params, model)
|
aws_region_name = self._get_aws_region_name(optional_params, model)
|
||||||
|
optional_params.pop("aws_region_name", None)
|
||||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||||
|
|
|
@ -10,6 +10,8 @@ import litellm
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
HTTPHandler,
|
||||||
_get_httpx_client,
|
_get_httpx_client,
|
||||||
get_async_httpx_client,
|
get_async_httpx_client,
|
||||||
)
|
)
|
||||||
|
@ -51,6 +53,7 @@ class BedrockImageGeneration(BaseAWSLLM):
|
||||||
aimg_generation: bool = False,
|
aimg_generation: bool = False,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||||
):
|
):
|
||||||
prepared_request = self._prepare_request(
|
prepared_request = self._prepare_request(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -69,9 +72,15 @@ class BedrockImageGeneration(BaseAWSLLM):
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
|
client=(
|
||||||
|
client
|
||||||
|
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
client = _get_httpx_client()
|
if client is None or not isinstance(client, HTTPHandler):
|
||||||
|
client = _get_httpx_client()
|
||||||
try:
|
try:
|
||||||
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -99,13 +108,14 @@ class BedrockImageGeneration(BaseAWSLLM):
|
||||||
logging_obj: LitellmLogging,
|
logging_obj: LitellmLogging,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model_response: ImageResponse,
|
model_response: ImageResponse,
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
) -> ImageResponse:
|
) -> ImageResponse:
|
||||||
"""
|
"""
|
||||||
Asynchronous handler for bedrock image generation
|
Asynchronous handler for bedrock image generation
|
||||||
|
|
||||||
Awaits the response from the bedrock image generation endpoint
|
Awaits the response from the bedrock image generation endpoint
|
||||||
"""
|
"""
|
||||||
async_client = get_async_httpx_client(
|
async_client = client or get_async_httpx_client(
|
||||||
llm_provider=litellm.LlmProviders.BEDROCK,
|
llm_provider=litellm.LlmProviders.BEDROCK,
|
||||||
params={"timeout": timeout},
|
params={"timeout": timeout},
|
||||||
)
|
)
|
||||||
|
|
|
@ -4521,6 +4521,7 @@ def image_generation( # noqa: PLR0915
|
||||||
non_default_params = {
|
non_default_params = {
|
||||||
k: v for k, v in kwargs.items() if k not in default_params
|
k: v for k, v in kwargs.items() if k not in default_params
|
||||||
} # model-specific params - pass them straight to the model/provider
|
} # model-specific params - pass them straight to the model/provider
|
||||||
|
|
||||||
optional_params = get_optional_params_image_gen(
|
optional_params = get_optional_params_image_gen(
|
||||||
model=model,
|
model=model,
|
||||||
n=n,
|
n=n,
|
||||||
|
@ -4532,6 +4533,7 @@ def image_generation( # noqa: PLR0915
|
||||||
custom_llm_provider=custom_llm_provider,
|
custom_llm_provider=custom_llm_provider,
|
||||||
**non_default_params,
|
**non_default_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging: Logging = litellm_logging_obj
|
logging: Logging = litellm_logging_obj
|
||||||
logging.update_environment_variables(
|
logging.update_environment_variables(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -4630,6 +4632,7 @@ def image_generation( # noqa: PLR0915
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
aimg_generation=aimg_generation,
|
aimg_generation=aimg_generation,
|
||||||
|
client=client,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "vertex_ai":
|
elif custom_llm_provider == "vertex_ai":
|
||||||
vertex_ai_project = (
|
vertex_ai_project = (
|
||||||
|
|
|
@ -11,6 +11,7 @@ import uuid
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union, cast
|
from typing import TYPE_CHECKING, List, Optional, Union, cast
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
|
@ -263,3 +263,25 @@ def test_cost_calculator_basic():
|
||||||
|
|
||||||
assert isinstance(cost, float)
|
assert isinstance(cost, float)
|
||||||
assert cost > 0
|
assert cost > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_bedrock_image_gen_with_aws_region_name():
|
||||||
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
from litellm import image_generation
|
||||||
|
|
||||||
|
client = HTTPHandler()
|
||||||
|
|
||||||
|
with patch.object(client, "post") as mock_post:
|
||||||
|
try:
|
||||||
|
image_generation(
|
||||||
|
model="bedrock/stability.stable-image-ultra-v1:1",
|
||||||
|
prompt="A beautiful sunset",
|
||||||
|
aws_region_name="us-west-2",
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise e
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
args, kwargs = mock_post.call_args
|
||||||
|
print(kwargs)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue