forked from phoenix/litellm-mirror
feat(bedrock.py): add stable diffusion image generation support
This commit is contained in:
parent
30c96ee872
commit
36416360c4
6 changed files with 314 additions and 17 deletions
|
@ -554,6 +554,7 @@ from .llms.bedrock import (
|
|||
AmazonAnthropicConfig,
|
||||
AmazonCohereConfig,
|
||||
AmazonLlamaConfig,
|
||||
AmazonStabilityConfig,
|
||||
)
|
||||
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
|
||||
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
|
||||
|
|
|
@ -2,9 +2,9 @@ import json, copy, types
|
|||
import os
|
||||
from enum import Enum
|
||||
import time
|
||||
from typing import Callable, Optional, Any, Union
|
||||
from typing import Callable, Optional, Any, Union, List
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
import httpx
|
||||
|
||||
|
@ -282,6 +282,73 @@ class AmazonLlamaConfig:
|
|||
}
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0

    Supported Params for the Amazon / Stable Diffusion models:

    - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)

    - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)

    - `steps` (integer): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.

    - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
        Engine-specific dimension validation:

        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536

    - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
        Engine-specific dimension validation:

        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536

    Note: values passed to the constructor are stored on the class itself
    (not on the instance), so they act as shared defaults that `get_config()`
    returns for every subsequent caller.
    """

    cfg_scale: Optional[int] = None
    seed: Optional[float] = None
    # A single step count (e.g. 30), not a list of strings.
    steps: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None

    def __init__(
        self,
        cfg_scale: Optional[int] = None,
        seed: Optional[float] = None,
        steps: Optional[int] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> None:
        """Record every non-None argument as a class-level default."""
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Deliberately set on the class: get_config() reads cls.__dict__.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level settings that are currently set.

        Dunder entries, functions / classmethods / staticmethods and
        attributes whose value is None are left out.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
||||
|
||||
|
||||
def init_bedrock_client(
|
||||
region_name=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
|
@ -290,6 +357,7 @@ def init_bedrock_client(
|
|||
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
@ -346,6 +414,8 @@ def init_bedrock_client(
|
|||
|
||||
import boto3
|
||||
|
||||
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
||||
|
||||
### CHECK STS ###
|
||||
if aws_role_name is not None and aws_session_name is not None:
|
||||
# use sts if role name passed in
|
||||
|
@ -366,6 +436,7 @@ def init_bedrock_client(
|
|||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
elif aws_access_key_id is not None:
|
||||
# uses auth params passed to completion
|
||||
|
@ -377,6 +448,7 @@ def init_bedrock_client(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
|
@ -386,6 +458,7 @@ def init_bedrock_client(
|
|||
service_name="bedrock-runtime",
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
|
||||
return client
|
||||
|
@ -855,3 +928,112 @@ def embedding(
|
|||
model_response.usage = usage
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
def image_generation(
    model: str,
    prompt: str,
    timeout=None,
    logging_obj=None,
    model_response=None,
    optional_params=None,
    aimg_generation=False,
):
    """
    Bedrock Image Gen endpoint support

    Builds a Stable Diffusion request from `prompt` and `optional_params`,
    sends it through a Bedrock runtime client's `invoke_model`, and stores the
    returned artifacts on `model_response`.

    Args:
        model: Bedrock model id; the part before the first "." selects the
            provider. Only "stability" is handled.
        prompt: Text prompt. Line separators are replaced with spaces.
        timeout: Passed through to `init_bedrock_client`.
        logging_obj: Object exposing `pre_call` / `post_call`; logging is
            skipped when it is None.
        model_response: Response object to populate; an `ImageResponse` is
            created when it is None.
        optional_params: Request parameters. The AWS auth / endpoint keys are
            popped from this dict (the caller's dict is mutated); the rest are
            sent as inference parameters.
        aimg_generation: Accepted for signature parity with the other
            providers; not used here, the call is made synchronously.

    Returns:
        `model_response`, with `data` set to a list holding one
        `{"url": <base64 payload>}` dict per returned artifact.

    Raises:
        BedrockError: 422 for an unsupported provider, 500 when the invoke /
            decode / post-call logging step fails.
        Exception: when the decoded response body is None.
    """
    if optional_params is None:
        optional_params = {}

    ### BOTO3 INIT ###
    # pop the aws auth / endpoint params so they are not forwarded to the model as inference params
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)
    aws_role_name = optional_params.pop("aws_role_name", None)
    aws_session_name = optional_params.pop("aws_session_name", None)
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )

    client = init_bedrock_client(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
        timeout=timeout,
    )

    ### FORMAT IMAGE GENERATION INPUT ###
    modelId = model
    provider = model.split(".")[0]
    inference_params = copy.deepcopy(optional_params)
    inference_params.pop(
        "user", None
    )  # make sure user is not passed in for bedrock call
    data = {}
    if provider == "stability":
        prompt = prompt.replace(os.linesep, " ")
        ## LOAD CONFIG
        config = litellm.AmazonStabilityConfig.get_config()
        for k, v in config.items():
            # params passed on the call take precedence over the class-level config
            inference_params.setdefault(k, v)
        data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
    else:
        raise BedrockError(
            status_code=422, message=f"Unsupported model={model}, passed in"
        )

    body = json.dumps(data).encode("utf-8")
    ## LOGGING
    request_str = f"""
    response = client.invoke_model(
        body={body},
        modelId={modelId},
        accept="application/json",
        contentType="application/json",
    )"""  # type: ignore
    if logging_obj is not None:
        logging_obj.pre_call(
            input=prompt,
            api_key="",  # boto3 is used for init.
            additional_args={
                "complete_input_dict": {"model": modelId, "texts": prompt},
                "request_str": request_str,
            },
        )
    try:
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())
        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=prompt,
                api_key="",
                additional_args={"complete_input_dict": data},
                original_response=json.dumps(response_body),
            )
    except Exception as e:
        raise BedrockError(
            message=f"Image generation error with model {model}: {e}",
            status_code=500,
        ) from e

    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    if response_body is None:
        raise Exception("Error in response object format")

    if model_response is None:
        model_response = ImageResponse()

    # One entry per artifact. The base64 payload stays under the "url" key,
    # which is the key this function has always emitted.
    image_list: List = [
        {"url": artifact["base64"]} for artifact in response_body["artifacts"]
    ]

    model_response.data = image_list
    return model_response
|
||||
|
|
|
@ -3153,7 +3153,18 @@ def image_generation(
|
|||
model_response=model_response,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
|
||||
elif custom_llm_provider == "bedrock":
|
||||
if model is None:
|
||||
raise Exception("Model needs to be set for bedrock")
|
||||
model_response = bedrock.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
timeout=timeout,
|
||||
logging_obj=litellm_logging_obj,
|
||||
optional_params=optional_params,
|
||||
model_response=model_response,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
## Map to OpenAI Exception
|
||||
|
|
|
@ -119,3 +119,43 @@ async def test_async_image_generation_azure():
|
|||
pass
|
||||
else:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
def test_image_generation_bedrock():
    """Smoke test: synchronous Bedrock Stable Diffusion image generation.

    Rate-limit and content-policy errors, and failures whose message mentions
    the safety system, are tolerated; any other exception fails the test.
    """
    safety_system_msg = "Your task failed as a result of our safety system."
    try:
        litellm.set_verbose = True
        response = litellm.image_generation(
            prompt="A cute baby sea otter",
            model="bedrock/stability.stable-diffusion-xl-v0",
            aws_region_name="us-east-1",
        )
        print(f"response: {response}")
    except (litellm.RateLimitError, litellm.ContentPolicyViolationError):
        # transient / provider-side rejections - skip when they occur
        pass
    except Exception as err:
        if safety_system_msg not in str(err):
            pytest.fail(f"An exception occurred - {str(err)}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aimage_generation_bedrock_with_optional_params():
    """Smoke test: async Bedrock image generation with an explicit `size`.

    Rate-limit and content-policy errors, and failures whose message mentions
    the safety system, are tolerated; any other exception fails the test.
    """
    safety_system_msg = "Your task failed as a result of our safety system."
    try:
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="bedrock/stability.stable-diffusion-xl-v0",
            size="128x128",
        )
        print(f"response: {response}")
    except (litellm.RateLimitError, litellm.ContentPolicyViolationError):
        # transient / provider-side rejections - skip when they occur
        pass
    except Exception as err:
        if safety_system_msg not in str(err):
            pytest.fail(f"An exception occurred - {str(err)}")
|
||||
|
|
|
@ -3445,21 +3445,42 @@ def get_optional_params_image_gen(
|
|||
for k, v in passed_params.items()
|
||||
if (k in default_params and v != default_params[k])
|
||||
}
|
||||
## raise exception if non-default value passed for non-openai/azure embedding calls
|
||||
if custom_llm_provider != "openai" and custom_llm_provider != "azure":
|
||||
if len(non_default_params.keys()) > 0:
|
||||
if litellm.drop_params is True: # drop the unsupported non-default values
|
||||
keys = list(non_default_params.keys())
|
||||
for k in keys:
|
||||
non_default_params.pop(k, None)
|
||||
return non_default_params
|
||||
raise UnsupportedParamsError(
|
||||
status_code=500,
|
||||
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
|
||||
)
|
||||
optional_params = {}
|
||||
|
||||
final_params = {**non_default_params, **kwargs}
|
||||
return final_params
|
||||
## raise exception if non-default value passed for non-openai/azure embedding calls
|
||||
def _check_valid_arg(supported_params):
|
||||
if len(non_default_params.keys()) > 0:
|
||||
keys = list(non_default_params.keys())
|
||||
for k in keys:
|
||||
if (
|
||||
litellm.drop_params is True and k not in supported_params
|
||||
): # drop the unsupported non-default values
|
||||
non_default_params.pop(k, None)
|
||||
elif k not in supported_params:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=500,
|
||||
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
|
||||
)
|
||||
return non_default_params
|
||||
|
||||
if (
|
||||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
):
|
||||
optional_params = non_default_params
|
||||
elif custom_llm_provider == "bedrock":
|
||||
supported_params = ["size"]
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
if size is not None:
|
||||
width, height = size.split("x")
|
||||
optional_params["width"] = int(width)
|
||||
optional_params["height"] = int(height)
|
||||
|
||||
for k in passed_params.keys():
|
||||
if k not in default_params.keys():
|
||||
optional_params[k] = passed_params[k]
|
||||
return optional_params
|
||||
|
||||
|
||||
def get_optional_params_embeddings(
|
||||
|
|
|
@ -1477,6 +1477,48 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.018,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"512-x-512/max-steps/stability.stable-diffusion-xl-v0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.036,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"max-x-max/50-steps/stability.stable-diffusion-xl-v0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.036,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"max-x-max/max-steps/stability.stable-diffusion-xl-v0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.072,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"1024-x-1024/50-steps/stability.stable-diffusion-xl-v1": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.04,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"1024-x-1024/max-steps/stability.stable-diffusion-xl-v1": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.08,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue