feat(bedrock.py): add stable diffusion image generation support

Krrish Dholakia 2024-02-03 12:08:38 -08:00
parent 30c96ee872
commit 36416360c4
6 changed files with 314 additions and 17 deletions

View file

@@ -554,6 +554,7 @@ from .llms.bedrock import (
AmazonAnthropicConfig,
AmazonCohereConfig,
AmazonLlamaConfig,
AmazonStabilityConfig,
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError

View file

@@ -2,9 +2,9 @@ import json, copy, types
import os
from enum import Enum
import time
-from typing import Callable, Optional, Any, Union
from typing import Callable, Optional, Any, Union, List
import litellm
-from litellm.utils import ModelResponse, get_secret, Usage
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
@@ -282,6 +282,73 @@ class AmazonLlamaConfig:
}
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (integer): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. Multiple of 64, >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. Multiple of 64, >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[int] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
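To make the config mechanics concrete, here is a minimal usage sketch with hypothetical values (it assumes the class is reachable as litellm.AmazonStabilityConfig, per the import hunk above): values set on the class act as defaults, and per-call parameters win in the merge loop that image_generation runs further down.

# Minimal sketch (hypothetical values): class-level config acts as defaults,
# per-call params take precedence in the merge loop used by image_generation.
import litellm

litellm.AmazonStabilityConfig(cfg_scale=10, steps=40)  # persisted on the class

inference_params = {"cfg_scale": 7}  # per-call value
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
    if k not in inference_params:  # don't overwrite per-call values
        inference_params[k] = v
# inference_params is now {"cfg_scale": 7, "steps": 40}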
def init_bedrock_client(
region_name=None,
aws_access_key_id: Optional[str] = None,
@@ -290,6 +357,7 @@ def init_bedrock_client(
aws_bedrock_runtime_endpoint: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
timeout: Optional[int] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
@@ -346,6 +414,8 @@ def init_bedrock_client(
import boto3
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
@@ -366,6 +436,7 @@ def init_bedrock_client(
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
elif aws_access_key_id is not None:
# uses auth params passed to completion
@@ -377,6 +448,7 @@ def init_bedrock_client(
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
@@ -386,6 +458,7 @@ def init_bedrock_client(
service_name="bedrock-runtime",
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
return client
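The new timeout parameter is applied to both the connect and read timeouts through boto3's Config object. A standalone sketch of the same pattern (example region; the real code resolves region and credentials from params, secrets, or the environment as shown above):

# Standalone sketch of the timeout wiring shown above.
import boto3

timeout = 300  # seconds
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
client = boto3.client(
    service_name="bedrock-runtime",
    region_name="us-west-2",  # example region, not from the diff
    config=config,
)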
@@ -855,3 +928,112 @@ def embedding(
model_response.usage = usage
return model_response
def image_generation(
model: str,
prompt: str,
timeout=None,
logging_obj=None,
model_response=None,
optional_params=None,
aimg_generation=False,
):
"""
Bedrock Image Gen endpoint support
"""
### BOTO3 INIT ###
# pop AWS auth params (aws_secret_access_key, aws_access_key_id, aws_region_name, etc.) from optional_params, since the invoke_model call would fail if they were passed through
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# create the BedrockRuntime client from the resolved auth params
client = init_bedrock_client(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
timeout=timeout,
)
### FORMAT IMAGE GENERATION INPUT ###
modelId = model
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
data = {}
if provider == "stability":
prompt = prompt.replace(os.linesep, " ")
## LOAD CONFIG
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # image_generation(cfg_scale=7) > AmazonStabilityConfig(cfg_scale=7) <- allows per-call params to override config defaults
inference_params[k] = v
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model} passed in. Only Stability models are currently supported for Bedrock image generation"
)
body = json.dumps(data).encode("utf-8")
## LOGGING
request_str = f"""
response = client.invoke_model(
body={body},
modelId={modelId},
accept="application/json",
contentType="application/json",
)""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="", # boto3 is used for init.
additional_args={
"complete_input_dict": {"model": modelId, "texts": prompt},
"request_str": request_str,
},
)
try:
response = client.invoke_model(
body=body,
modelId=modelId,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body),
)
except Exception as e:
raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
### FORMAT RESPONSE TO OPENAI FORMAT ###
if response_body is None:
raise Exception("Error in response object format")
if model_response is None:
model_response = ImageResponse()
image_list: List = []
for artifact in response_body["artifacts"]:
    image_list.append({"url": artifact["base64"]})
model_response.data = image_list
return model_response
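Note that Stability artifacts come back base64-encoded, and the code above stores each one under the "url" key of an entry in model_response.data. A minimal caller-side sketch for saving them to disk (output path is arbitrary):

# Minimal caller-side sketch: decode the base64 payload stored under "url".
import base64
import litellm

response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="bedrock/stability.stable-diffusion-xl-v0",
)
for i, image in enumerate(response.data):
    with open(f"image_{i}.png", "wb") as f:  # arbitrary output path
        f.write(base64.b64decode(image["url"]))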

View file

@@ -3153,7 +3153,18 @@ def image_generation(
model_response=model_response,
aimg_generation=aimg_generation,
)
elif custom_llm_provider == "bedrock":
if model is None:
raise Exception("Model needs to be set for bedrock")
model_response = bedrock.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
aimg_generation=aimg_generation,
)
return model_response
except Exception as e:
## Map to OpenAI Exception
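The bedrock branch forwards into bedrock.image_generation, which pops AWS auth params off optional_params before creating the boto3 client. A hedged sketch of passing them per call (placeholder credentials; param names match the pop calls above):

# Sketch: per-call AWS auth (placeholder credentials). These kwargs are popped
# off optional_params inside bedrock.image_generation before boto3 init.
import litellm

response = litellm.image_generation(
    prompt="A watercolor lighthouse",
    model="bedrock/stability.stable-diffusion-xl-v0",
    aws_access_key_id="...",
    aws_secret_access_key="...",
    aws_region_name="us-east-1",
)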

View file

@@ -119,3 +119,43 @@ async def test_async_image_generation_azure():
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
def test_image_generation_bedrock():
try:
litellm.set_verbose = True
response = litellm.image_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0",
aws_region_name="us-east-1",
)
print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # content policy violations can occur intermittently - skip when they do
except Exception as e:
if "Your task failed as a result of our safety system." in str(e):
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.asyncio
async def test_aimage_generation_bedrock_with_optional_params():
try:
response = await litellm.aimage_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0",
size="128x128",
)
print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # content policy violations can occur intermittently - skip when they do
except Exception as e:
if "Your task failed as a result of our safety system." in str(e):
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
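Outside pytest, the async path can be exercised with asyncio directly; a minimal sketch mirroring the test above:

# Minimal async sketch mirroring the test above (run outside pytest).
import asyncio
import litellm

async def main():
    response = await litellm.aimage_generation(
        prompt="A cute baby sea otter",
        model="bedrock/stability.stable-diffusion-xl-v0",
        size="512x512",
    )
    print(response)

asyncio.run(main())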

View file

@@ -3445,21 +3445,42 @@ def get_optional_params_image_gen(
for k, v in passed_params.items()
if (k in default_params and v != default_params[k])
}
-## raise exception if non-default value passed for non-openai/azure embedding calls
-if custom_llm_provider != "openai" and custom_llm_provider != "azure":
-    if len(non_default_params.keys()) > 0:
-        if litellm.drop_params is True:  # drop the unsupported non-default values
-            keys = list(non_default_params.keys())
-            for k in keys:
-                non_default_params.pop(k, None)
-            return non_default_params
-        raise UnsupportedParamsError(
-            status_code=500,
-            message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
-        )
-final_params = {**non_default_params, **kwargs}
-return final_params
optional_params = {}
## raise exception if non-default value passed for non-openai/azure embedding calls
def _check_valid_arg(supported_params):
if len(non_default_params.keys()) > 0:
keys = list(non_default_params.keys())
for k in keys:
if (
litellm.drop_params is True and k not in supported_params
): # drop the unsupported non-default values
non_default_params.pop(k, None)
elif k not in supported_params:
raise UnsupportedParamsError(
status_code=500,
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
)
return non_default_params
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider in litellm.openai_compatible_providers
):
optional_params = non_default_params
elif custom_llm_provider == "bedrock":
supported_params = ["size"]
_check_valid_arg(supported_params=supported_params)
if size is not None:
width, height = size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)
for k in passed_params.keys():
if k not in default_params.keys():
optional_params[k] = passed_params[k]
return optional_params
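For bedrock, "size" is the only supported image-generation param here: the OpenAI-style "WxH" string is split into Stability's integer width/height, and any other non-default param is either dropped (when litellm.drop_params is True) or raises UnsupportedParamsError. A quick sketch of the mapping:

# Sketch of the bedrock size mapping: OpenAI-style "WxH" -> width/height ints.
size = "1024x1024"
width, height = size.split("x")
optional_params = {"width": int(width), "height": int(height)}
# -> {"width": 1024, "height": 1024}; with litellm.drop_params = True, any
# unsupported non-default param (e.g. n=2) is dropped instead of raising.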
def get_optional_params_embeddings(

View file

@@ -1477,6 +1477,48 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.018,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"512-x-512/max-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.036,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"max-x-max/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.036,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"max-x-max/max-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.072,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"1024-x-1024/50-steps/stability.stable-diffusion-xl-v1": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.04,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"1024-x-1024/max-steps/stability.stable-diffusion-xl-v1": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.08,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,