Merge pull request #6672 from BerriAI/litellm_add_async_bedrock_image_gen

(feat) add bedrock image gen async support
This commit is contained in:
Ishaan Jaff 2024-11-08 19:25:02 -08:00 committed by GitHub
commit 0871c33a24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 557 additions and 200 deletions

View file

@ -625,6 +625,48 @@ jobs:
paths: paths:
- llm_translation_coverage.xml - llm_translation_coverage.xml
- llm_translation_coverage - llm_translation_coverage
image_gen_testing:
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project
steps:
- checkout
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-cov==5.0.0"
pip install "pytest-asyncio==0.21.1"
pip install "respx==0.21.1"
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/image_gen_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml image_gen_coverage.xml
mv .coverage image_gen_coverage
# Store test results
- store_test_results:
path: test-results
- persist_to_workspace:
root: .
paths:
- image_gen_coverage.xml
- image_gen_coverage
logging_testing: logging_testing:
docker: docker:
- image: cimg/python:3.11 - image: cimg/python:3.11
@ -877,7 +919,7 @@ jobs:
command: | command: |
pwd pwd
ls ls
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests
no_output_timeout: 120m no_output_timeout: 120m
# Store test results # Store test results
@ -1114,7 +1156,7 @@ jobs:
python -m venv venv python -m venv venv
. venv/bin/activate . venv/bin/activate
pip install coverage pip install coverage
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage
coverage xml coverage xml
- codecov/upload: - codecov/upload:
file: ./coverage.xml file: ./coverage.xml
@ -1403,6 +1445,12 @@ workflows:
only: only:
- main - main
- /litellm_.*/ - /litellm_.*/
- image_gen_testing:
filters:
branches:
only:
- main
- /litellm_.*/
- logging_testing: - logging_testing:
filters: filters:
branches: branches:
@ -1412,6 +1460,7 @@ workflows:
- upload-coverage: - upload-coverage:
requires: requires:
- llm_translation_testing - llm_translation_testing
- image_gen_testing
- logging_testing - logging_testing
- litellm_router_testing - litellm_router_testing
- caching_unit_tests - caching_unit_tests
@ -1451,6 +1500,7 @@ workflows:
- load_testing - load_testing
- test_bad_database_url - test_bad_database_url
- llm_translation_testing - llm_translation_testing
- image_gen_testing
- logging_testing - logging_testing
- litellm_router_testing - litellm_router_testing
- caching_unit_tests - caching_unit_tests

View file

@ -984,10 +984,10 @@ from .llms.bedrock.common_utils import (
AmazonAnthropicClaude3Config, AmazonAnthropicClaude3Config,
AmazonCohereConfig, AmazonCohereConfig,
AmazonLlamaConfig, AmazonLlamaConfig,
AmazonStabilityConfig,
AmazonMistralConfig, AmazonMistralConfig,
AmazonBedrockGlobalConfig, AmazonBedrockGlobalConfig,
) )
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import ( from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
AmazonTitanMultimodalEmbeddingG1Config, AmazonTitanMultimodalEmbeddingG1Config,

View file

@ -1,16 +1,28 @@
import hashlib import hashlib
import json import json
import os import os
from typing import Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import httpx import httpx
from pydantic import BaseModel
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache, InMemoryCache from litellm.caching.caching import DualCache, InMemoryCache
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret, get_secret_str
from .base import BaseLLM from .base import BaseLLM
if TYPE_CHECKING:
from botocore.credentials import Credentials
else:
Credentials = Any
class Boto3CredentialsInfo(BaseModel):
credentials: Credentials
aws_region_name: str
aws_bedrock_runtime_endpoint: Optional[str]
class AwsAuthError(Exception): class AwsAuthError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
@ -311,3 +323,74 @@ class BaseAWSLLM(BaseLLM):
proxy_endpoint_url = endpoint_url proxy_endpoint_url = endpoint_url
return endpoint_url, proxy_endpoint_url return endpoint_url, proxy_endpoint_url
def _get_boto_credentials_from_optional_params(
self, optional_params: dict
) -> Boto3CredentialsInfo:
"""
Get boto3 credentials from optional params
Args:
optional_params (dict): Optional parameters for the model call
Returns:
Credentials: Boto3 credentials object
"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret_str("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
)
return Boto3CredentialsInfo(
credentials=credentials,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
)

View file

@ -484,73 +484,6 @@ class AmazonMistralConfig:
} }
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def add_custom_header(headers): def add_custom_header(headers):
"""Closure to capture the headers and add them.""" """Closure to capture the headers and add them."""

View file

@ -0,0 +1,69 @@
import types
from typing import List, Optional
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

View file

@ -0,0 +1,275 @@
import copy
import json
import os
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from openai.types.image import Image
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import ImageResponse
from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockImagePreparedRequest(BaseModel):
"""
Internal/Helper class for preparing the request for bedrock image generation
"""
endpoint_url: str
prepped: AWSPreparedRequest
body: bytes
data: dict
class BedrockImageGeneration(BaseAWSLLM):
"""
Bedrock Image Generation handler
"""
def image_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: LitellmLogging,
timeout: Optional[Union[float, httpx.Timeout]],
aimg_generation: bool = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
):
prepared_request = self._prepare_request(
model=model,
optional_params=optional_params,
api_base=api_base,
extra_headers=extra_headers,
logging_obj=logging_obj,
prompt=prompt,
)
if aimg_generation is True:
return self.async_image_generation(
prepared_request=prepared_request,
timeout=timeout,
model=model,
logging_obj=logging_obj,
prompt=prompt,
model_response=model_response,
)
client = _get_httpx_client()
try:
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
### FORMAT RESPONSE TO OPENAI FORMAT ###
model_response = self._transform_response_dict_to_openai_response(
model_response=model_response,
model=model,
logging_obj=logging_obj,
prompt=prompt,
response=response,
data=prepared_request.data,
)
return model_response
async def async_image_generation(
self,
prepared_request: BedrockImagePreparedRequest,
timeout: Optional[Union[float, httpx.Timeout]],
model: str,
logging_obj: LitellmLogging,
prompt: str,
model_response: ImageResponse,
) -> ImageResponse:
"""
Asynchronous handler for bedrock image generation
Awaits the response from the bedrock image generation endpoint
"""
async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.BEDROCK,
params={"timeout": timeout},
)
try:
response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
### FORMAT RESPONSE TO OPENAI FORMAT ###
model_response = self._transform_response_dict_to_openai_response(
model=model,
logging_obj=logging_obj,
prompt=prompt,
response=response,
data=prepared_request.data,
model_response=model_response,
)
return model_response
def _prepare_request(
self,
model: str,
optional_params: dict,
api_base: Optional[str],
extra_headers: Optional[dict],
logging_obj: LitellmLogging,
prompt: str,
) -> BedrockImagePreparedRequest:
"""
Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
Args:
model (str): The model to use for the image generation
optional_params (dict): The optional parameters for the image generation
api_base (Optional[str]): The base URL for the Bedrock API
extra_headers (Optional[dict]): The extra headers to include in the request
logging_obj (LitellmLogging): The logging object to use for logging
prompt (str): The prompt to use for the image generation
Returns:
BedrockImagePreparedRequest: The prepared request object
The BedrockImagePreparedRequest contains:
endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
prepped (httpx.Request): The prepared request object
body (bytes): The request body
"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
)
### SET RUNTIME ENDPOINT ###
modelId = model
_, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
sigv4 = SigV4Auth(
boto3_credentials_info.credentials,
"bedrock",
boto3_credentials_info.aws_region_name,
)
# transform request
### FORMAT IMAGE GENERATION INPUT ###
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
data = {}
if provider == "stability":
prompt = prompt.replace(os.linesep, " ")
## LOAD CONFIG
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
)
# Make POST Request
body = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=proxy_endpoint_url, data=body, headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": proxy_endpoint_url,
"headers": prepped.headers,
},
)
return BedrockImagePreparedRequest(
endpoint_url=proxy_endpoint_url,
prepped=prepped,
body=body,
data=data,
)
def _transform_response_dict_to_openai_response(
self,
model_response: ImageResponse,
model: str,
logging_obj: LitellmLogging,
prompt: str,
response: httpx.Response,
data: dict,
) -> ImageResponse:
"""
Transforms the Image Generation response from Bedrock to OpenAI format
"""
## LOGGING
if logging_obj is not None:
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
verbose_logger.debug("raw model_response: %s", response.text)
response_dict = response.json()
if response_dict is None:
raise ValueError("Error in response object format, got None")
image_list: List[Image] = []
for artifact in response_dict["artifacts"]:
_image = Image(b64_json=artifact["base64"])
image_list.append(_image)
model_response.data = image_list
return model_response

View file

@ -0,0 +1,73 @@
import copy
import os
import types
from typing import Any, Dict, List, Optional, TypedDict, Union
import litellm
class AmazonStability1Config:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

View file

@ -1,127 +0,0 @@
"""
Handles image gen calls to Bedrock's `/invoke` endpoint
"""
import copy
import json
import os
from typing import Any, List
from openai.types.image import Image
import litellm
from litellm.types.utils import ImageResponse
from .common_utils import BedrockError, init_bedrock_client
def image_generation(
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout=None,
aimg_generation=False,
):
"""
Bedrock Image Gen endpoint support
"""
### BOTO3 INIT ###
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
timeout=timeout,
)
### FORMAT IMAGE GENERATION INPUT ###
modelId = model
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
data = {}
if provider == "stability":
prompt = prompt.replace(os.linesep, " ")
## LOAD CONFIG
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
)
body = json.dumps(data).encode("utf-8")
## LOGGING
request_str = f"""
response = client.invoke_model(
body={body}, # type: ignore
modelId={modelId},
accept="application/json",
contentType="application/json",
)""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="", # boto3 is used for init.
additional_args={
"complete_input_dict": {"model": modelId, "texts": prompt},
"request_str": request_str,
},
)
try:
response = client.invoke_model(
body=body,
modelId=modelId,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body),
)
except Exception as e:
raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
### FORMAT RESPONSE TO OPENAI FORMAT ###
if response_body is None:
raise Exception("Error in response object format")
if model_response is None:
model_response = ImageResponse()
image_list: List[Image] = []
for artifact in response_body["artifacts"]:
_image = Image(b64_json=artifact["base64"])
image_list.append(_image)
model_response.data = image_list
return model_response

View file

@ -108,9 +108,9 @@ from .llms.azure_text import AzureTextCompletion
from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.cohere import chat as cohere_chat from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion # type: ignore from .llms.cohere import completion as cohere_completion # type: ignore
from .llms.cohere.embed import handler as cohere_embed from .llms.cohere.embed import handler as cohere_embed
@ -214,6 +214,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM() bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM() bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding() bedrock_embedding = BedrockEmbedding()
bedrock_image_generation = BedrockImageGeneration()
vertex_chat_completion = VertexLLM() vertex_chat_completion = VertexLLM()
vertex_embedding = VertexEmbedding() vertex_embedding = VertexEmbedding()
vertex_multimodal_embedding = VertexMultimodalEmbedding() vertex_multimodal_embedding = VertexMultimodalEmbedding()