forked from phoenix/litellm-mirror
added async support for bedrock image gen
This commit is contained in:
parent
64c3c4906c
commit
092888d593
5 changed files with 300 additions and 108 deletions
|
@ -984,10 +984,10 @@ from .llms.bedrock.common_utils import (
|
|||
AmazonAnthropicClaude3Config,
|
||||
AmazonCohereConfig,
|
||||
AmazonLlamaConfig,
|
||||
AmazonStabilityConfig,
|
||||
AmazonMistralConfig,
|
||||
AmazonBedrockGlobalConfig,
|
||||
)
|
||||
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
|
||||
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
|
||||
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
|
||||
AmazonTitanMultimodalEmbeddingG1Config,
|
||||
|
|
|
@ -484,73 +484,6 @@ class AmazonMistralConfig:
|
|||
}
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def add_custom_header(headers):
|
||||
"""Closure to capture the headers and add them."""
|
||||
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
|
@ -1,38 +1,163 @@
|
|||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai.types.image import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, _get_httpx_client
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
from litellm.utils import print_verbose
|
||||
|
||||
from ...base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockImagePreparedRequest(BaseModel):
|
||||
"""
|
||||
Internal/Helper class for preparing the request for bedrock image generation
|
||||
"""
|
||||
|
||||
endpoint_url: str
|
||||
prepped: AWSPreparedRequest
|
||||
body: bytes
|
||||
data: dict
|
||||
|
||||
|
||||
class BedrockImageGeneration(BaseAWSLLM):
|
||||
"""
|
||||
Bedrock Image Generation handler
|
||||
"""
|
||||
|
||||
def image_generation( # noqa: PLR0915
|
||||
def image_generation(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
timeout=None,
|
||||
logging_obj: LitellmLogging,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
aimg_generation: bool = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Any] = None,
|
||||
):
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
if aimg_generation is True:
|
||||
return self.async_image_generation(
|
||||
prepared_request=prepared_request,
|
||||
timeout=timeout,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
model_response=model_response,
|
||||
)
|
||||
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
model_response = self._transform_response_dict_to_openai_response(
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
data=prepared_request.data,
|
||||
)
|
||||
return model_response
|
||||
|
||||
async def async_image_generation(
|
||||
self,
|
||||
prepared_request: BedrockImagePreparedRequest,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
model: str,
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Asynchronous handler for bedrock image generation
|
||||
|
||||
Awaits the response from the bedrock image generation endpoint
|
||||
"""
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BEDROCK,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
model_response = self._transform_response_dict_to_openai_response(
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
data=prepared_request.data,
|
||||
model_response=model_response,
|
||||
)
|
||||
return model_response
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
) -> BedrockImagePreparedRequest:
|
||||
"""
|
||||
Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
|
||||
|
||||
Args:
|
||||
model (str): The model to use for the image generation
|
||||
optional_params (dict): The optional parameters for the image generation
|
||||
api_base (Optional[str]): The base URL for the Bedrock API
|
||||
extra_headers (Optional[dict]): The extra headers to include in the request
|
||||
logging_obj (LitellmLogging): The logging object to use for logging
|
||||
prompt (str): The prompt to use for the image generation
|
||||
Returns:
|
||||
BedrockImagePreparedRequest: The prepared request object
|
||||
|
||||
The BedrockImagePreparedRequest contains:
|
||||
endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
|
||||
prepped (httpx.Request): The prepared request object
|
||||
body (bytes): The request body
|
||||
"""
|
||||
try:
|
||||
import boto3
|
||||
from botocore.auth import SigV4Auth
|
||||
|
@ -46,7 +171,7 @@ class BedrockImageGeneration(BaseAWSLLM):
|
|||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
modelId = model
|
||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
_, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||
|
@ -107,27 +232,25 @@ class BedrockImageGeneration(BaseAWSLLM):
|
|||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
return BedrockImagePreparedRequest(
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
prepped=prepped,
|
||||
body=body,
|
||||
data=data,
|
||||
)
|
||||
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = _get_httpx_client(_params) # type: ignore
|
||||
else:
|
||||
client = client
|
||||
|
||||
try:
|
||||
response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=body) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
response_body = response.json()
|
||||
def _transform_response_dict_to_openai_response(
|
||||
self,
|
||||
model_response: ImageResponse,
|
||||
model: str,
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
response: httpx.Response,
|
||||
data: dict,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transforms the Image Generation response from Bedrock to OpenAI format
|
||||
"""
|
||||
|
||||
## LOGGING
|
||||
if logging_obj is not None:
|
||||
|
@ -137,22 +260,16 @@ class BedrockImageGeneration(BaseAWSLLM):
|
|||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose("raw model_response: %s", response.text)
|
||||
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
if response_body is None:
|
||||
raise Exception("Error in response object format")
|
||||
|
||||
if model_response is None:
|
||||
model_response = ImageResponse()
|
||||
verbose_logger.debug("raw model_response: %s", response.text)
|
||||
response_dict = response.json()
|
||||
if response_dict is None:
|
||||
raise ValueError("Error in response object format, got None")
|
||||
|
||||
image_list: List[Image] = []
|
||||
for artifact in response_body["artifacts"]:
|
||||
for artifact in response_dict["artifacts"]:
|
||||
_image = Image(b64_json=artifact["base64"])
|
||||
image_list.append(_image)
|
||||
|
||||
model_response.data = image_list
|
||||
return model_response
|
||||
|
||||
async def async_image_generation(self):
|
||||
pass
|
||||
return model_response
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
import copy
|
||||
import os
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class AmazonStability1Config:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue