added async support for bedrock image gen

This commit is contained in:
Ishaan Jaff 2024-11-08 14:04:04 -08:00
parent 64c3c4906c
commit 092888d593
5 changed files with 300 additions and 108 deletions

View file

@ -984,10 +984,10 @@ from .llms.bedrock.common_utils import (
AmazonAnthropicClaude3Config,
AmazonCohereConfig,
AmazonLlamaConfig,
AmazonStabilityConfig,
AmazonMistralConfig,
AmazonBedrockGlobalConfig,
)
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
AmazonTitanMultimodalEmbeddingG1Config,

View file

@ -484,73 +484,6 @@ class AmazonMistralConfig:
}
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def add_custom_header(headers):
"""Closure to capture the headers and add them."""

View file

@ -0,0 +1,69 @@
import types
from typing import List, Optional
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

View file

@ -1,38 +1,163 @@
import copy
import json
import os
from typing import Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from openai.types.image import Image
from pydantic import BaseModel
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, _get_httpx_client
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import ImageResponse
from litellm.utils import print_verbose
from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockImagePreparedRequest(BaseModel):
"""
Internal/Helper class for preparing the request for bedrock image generation
"""
endpoint_url: str
prepped: AWSPreparedRequest
body: bytes
data: dict
class BedrockImageGeneration(BaseAWSLLM):
"""
Bedrock Image Generation handler
"""
def image_generation( # noqa: PLR0915
def image_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout=None,
logging_obj: LitellmLogging,
timeout: Optional[Union[float, httpx.Timeout]],
aimg_generation: bool = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
client: Optional[Any] = None,
):
prepared_request = self._prepare_request(
model=model,
optional_params=optional_params,
api_base=api_base,
extra_headers=extra_headers,
logging_obj=logging_obj,
prompt=prompt,
)
if aimg_generation is True:
return self.async_image_generation(
prepared_request=prepared_request,
timeout=timeout,
model=model,
logging_obj=logging_obj,
prompt=prompt,
model_response=model_response,
)
client = _get_httpx_client()
try:
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
### FORMAT RESPONSE TO OPENAI FORMAT ###
model_response = self._transform_response_dict_to_openai_response(
model_response=model_response,
model=model,
logging_obj=logging_obj,
prompt=prompt,
response=response,
data=prepared_request.data,
)
return model_response
async def async_image_generation(
self,
prepared_request: BedrockImagePreparedRequest,
timeout: Optional[Union[float, httpx.Timeout]],
model: str,
logging_obj: LitellmLogging,
prompt: str,
model_response: ImageResponse,
) -> ImageResponse:
"""
Asynchronous handler for bedrock image generation
Awaits the response from the bedrock image generation endpoint
"""
async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.BEDROCK,
params={"timeout": timeout},
)
try:
response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
### FORMAT RESPONSE TO OPENAI FORMAT ###
model_response = self._transform_response_dict_to_openai_response(
model=model,
logging_obj=logging_obj,
prompt=prompt,
response=response,
data=prepared_request.data,
model_response=model_response,
)
return model_response
def _prepare_request(
self,
model: str,
optional_params: dict,
api_base: Optional[str],
extra_headers: Optional[dict],
logging_obj: LitellmLogging,
prompt: str,
) -> BedrockImagePreparedRequest:
"""
Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
Args:
model (str): The model to use for the image generation
optional_params (dict): The optional parameters for the image generation
api_base (Optional[str]): The base URL for the Bedrock API
extra_headers (Optional[dict]): The extra headers to include in the request
logging_obj (LitellmLogging): The logging object to use for logging
prompt (str): The prompt to use for the image generation
Returns:
BedrockImagePreparedRequest: The prepared request object
The BedrockImagePreparedRequest contains:
endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
prepped (httpx.Request): The prepared request object
body (bytes): The request body
"""
try:
import boto3
from botocore.auth import SigV4Auth
@ -46,7 +171,7 @@ class BedrockImageGeneration(BaseAWSLLM):
### SET RUNTIME ENDPOINT ###
modelId = model
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
_, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
@ -107,27 +232,25 @@ class BedrockImageGeneration(BaseAWSLLM):
"headers": prepped.headers,
},
)
return BedrockImagePreparedRequest(
endpoint_url=proxy_endpoint_url,
prepped=prepped,
body=body,
data=data,
)
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = _get_httpx_client(_params) # type: ignore
else:
client = client
try:
response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=body) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
response_body = response.json()
def _transform_response_dict_to_openai_response(
self,
model_response: ImageResponse,
model: str,
logging_obj: LitellmLogging,
prompt: str,
response: httpx.Response,
data: dict,
) -> ImageResponse:
"""
Transforms the Image Generation response from Bedrock to OpenAI format
"""
## LOGGING
if logging_obj is not None:
@ -137,22 +260,16 @@ class BedrockImageGeneration(BaseAWSLLM):
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose("raw model_response: %s", response.text)
### FORMAT RESPONSE TO OPENAI FORMAT ###
if response_body is None:
raise Exception("Error in response object format")
if model_response is None:
model_response = ImageResponse()
verbose_logger.debug("raw model_response: %s", response.text)
response_dict = response.json()
if response_dict is None:
raise ValueError("Error in response object format, got None")
image_list: List[Image] = []
for artifact in response_body["artifacts"]:
for artifact in response_dict["artifacts"]:
_image = Image(b64_json=artifact["base64"])
image_list.append(_image)
model_response.data = image_list
return model_response
async def async_image_generation(self):
pass
return model_response

View file

@ -0,0 +1,73 @@
import copy
import os
import types
from typing import Any, Dict, List, Optional, TypedDict, Union
import litellm
class AmazonStability1Config:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}