diff --git a/litellm/__init__.py b/litellm/__init__.py
index 0021daef0..f1f8325dd 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -554,6 +554,7 @@ from .llms.bedrock import (
     AmazonAnthropicConfig,
     AmazonCohereConfig,
     AmazonLlamaConfig,
+    AmazonStabilityConfig,
 )
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index b67061c76..0fe9b72e3 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -2,9 +2,9 @@ import json, copy, types
 import os
 from enum import Enum
 import time
-from typing import Callable, Optional, Any, Union
+from typing import Callable, Optional, Any, Union, List
 import litellm
-from litellm.utils import ModelResponse, get_secret, Usage
+from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
 from .prompt_templates.factory import prompt_factory, custom_prompt
 import httpx
@@ -282,6 +282,73 @@ class AmazonLlamaConfig:
     }
 
 
+class AmazonStabilityConfig:
+    """
+    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
+
+    Supported Params for the Amazon / Stable Diffusion models:
+
+    - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
+
+    - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
+
+    - `steps` (integer): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
+
+    - `width` (integer): Default: `512`. Multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
+        Engine-specific dimension validation:
+
+        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
+        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
+        - SDXL v1.0: same as SDXL v0.9
+        - SD v1.6: must be between 320x320 and 1536x1536
+
+    - `height` (integer): Default: `512`. Multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
+        Engine-specific dimension validation:
+
+        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
+        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
+        - SDXL v1.0: same as SDXL v0.9
+        - SD v1.6: must be between 320x320 and 1536x1536
+    """
+
+    cfg_scale: Optional[int] = None
+    seed: Optional[float] = None
+    steps: Optional[int] = None
+    width: Optional[int] = None
+    height: Optional[int] = None
+
+    def __init__(
+        self,
+        cfg_scale: Optional[int] = None,
+        seed: Optional[float] = None,
+        steps: Optional[int] = None,
+        width: Optional[int] = None,
+        height: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+
 def init_bedrock_client(
     region_name=None,
     aws_access_key_id: Optional[str] = None,
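Annotation (not part of the patch): values passed to the `AmazonStabilityConfig` constructor are stored as class-level defaults via `setattr`, and `get_config()` returns every non-None default so callers can merge them under their own params. A minimal sketch of that merge, mirroring the loop in `bedrock.image_generation` further down:

```python
# Minimal sketch, assuming litellm is importable; mirrors the merge loop
# in bedrock.image_generation below.
import litellm

# The constructor stores these as class-level defaults.
litellm.AmazonStabilityConfig(cfg_scale=10, steps=40)

inference_params = {"cfg_scale": 5}  # param supplied on the call itself
for k, v in litellm.AmazonStabilityConfig.get_config().items():
    if k not in inference_params:  # call-level params win over config defaults
        inference_params[k] = v

print(inference_params)  # {'cfg_scale': 5, 'steps': 40}
```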
@@ -290,6 +357,7 @@ def init_bedrock_client(
     aws_bedrock_runtime_endpoint: Optional[str] = None,
     aws_session_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
+    timeout: Optional[int] = None,
 ):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
@@ -346,6 +414,8 @@
 
     import boto3
 
+    config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+
     ### CHECK STS ###
     if aws_role_name is not None and aws_session_name is not None:
         # use sts if role name passed in
@@ -366,6 +436,7 @@
             aws_session_token=sts_response["Credentials"]["SessionToken"],
             region_name=region_name,
             endpoint_url=endpoint_url,
+            config=config,
         )
     elif aws_access_key_id is not None:
         # uses auth params passed to completion
@@ -377,6 +448,7 @@
             aws_secret_access_key=aws_secret_access_key,
             region_name=region_name,
             endpoint_url=endpoint_url,
+            config=config,
         )
     else:
         # aws_access_key_id is None, assume user is trying to auth using env variables
@@ -386,6 +458,7 @@
             service_name="bedrock-runtime",
             region_name=region_name,
             endpoint_url=endpoint_url,
+            config=config,
         )
 
     return client
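Side note (annotation, not part of the patch): `boto3.session.Config` is the botocore client config, so the single `timeout` value bounds both the connect and read phases of every Bedrock HTTP call. A standalone sketch of the same wiring; the region value is illustrative:

```python
# Standalone sketch of the timeout wiring above.
import boto3

timeout = 300  # seconds, applied to both phases below
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
client = boto3.client(
    service_name="bedrock-runtime",
    region_name="us-east-1",  # illustrative region
    config=config,
)
```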
@@ -855,3 +928,112 @@ def embedding(
     model_response.usage = usage
 
     return model_response
+
+
+def image_generation(
+    model: str,
+    prompt: str,
+    timeout=None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    aimg_generation=False,
+):
+    """
+    Bedrock Image Gen endpoint support
+    """
+    ### BOTO3 INIT ###
+    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
+    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+    aws_region_name = optional_params.pop("aws_region_name", None)
+    aws_role_name = optional_params.pop("aws_role_name", None)
+    aws_session_name = optional_params.pop("aws_session_name", None)
+    aws_bedrock_runtime_endpoint = optional_params.pop(
+        "aws_bedrock_runtime_endpoint", None
+    )
+
+    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
+    client = init_bedrock_client(
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+        aws_region_name=aws_region_name,
+        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
+        aws_role_name=aws_role_name,
+        aws_session_name=aws_session_name,
+        timeout=timeout,
+    )
+
+    ### FORMAT IMAGE GENERATION INPUT ###
+    modelId = model
+    provider = model.split(".")[0]
+    inference_params = copy.deepcopy(optional_params)
+    inference_params.pop(
+        "user", None
+    )  # make sure user is not passed in for bedrock call
+    data = {}
+    if provider == "stability":
+        prompt = prompt.replace(os.linesep, " ")
+        ## LOAD CONFIG
+        config = litellm.AmazonStabilityConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in inference_params
+            ):  # params passed on the call take precedence over AmazonStabilityConfig defaults
+                inference_params[k] = v
+        data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
+    else:
+        raise BedrockError(
+            status_code=422, message=f"Unsupported model={model}, passed in"
+        )
+
+    body = json.dumps(data).encode("utf-8")
+    ## LOGGING
+    request_str = f"""
+    response = client.invoke_model(
+        body={body},
+        modelId={modelId},
+        accept="application/json",
+        contentType="application/json",
+    )"""  # type: ignore
+    logging_obj.pre_call(
+        input=prompt,
+        api_key="",  # boto3 is used for init.
+        additional_args={
+            "complete_input_dict": {"model": modelId, "texts": prompt},
+            "request_str": request_str,
+        },
+    )
+    try:
+        response = client.invoke_model(
+            body=body,
+            modelId=modelId,
+            accept="application/json",
+            contentType="application/json",
+        )
+        response_body = json.loads(response.get("body").read())
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key="",
+            additional_args={"complete_input_dict": data},
+            original_response=json.dumps(response_body),
+        )
+    except Exception as e:
+        raise BedrockError(
+            message=f"Image generation error with model {model}: {e}", status_code=500
+        )
+
+    ### FORMAT RESPONSE TO OPENAI FORMAT ###
+    if response_body is None:
+        raise Exception("Error in response object format")
+
+    if model_response is None:
+        model_response = ImageResponse()
+
+    image_list: List = []
+    for artifact in response_body["artifacts"]:
+        image_list.append({"url": artifact["base64"]})  # artifact payload is base64 image data
+
+    model_response.data = image_list
+    return model_response
diff --git a/litellm/main.py b/litellm/main.py
index a30d6a8e4..1c8077396 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3153,7 +3153,18 @@ def image_generation(
                 model_response=model_response,
                 aimg_generation=aimg_generation,
             )
-
+        elif custom_llm_provider == "bedrock":
+            if model is None:
+                raise Exception("Model needs to be set for bedrock")
+            model_response = bedrock.image_generation(
+                model=model,
+                prompt=prompt,
+                timeout=timeout,
+                logging_obj=litellm_logging_obj,
+                optional_params=optional_params,
+                model_response=model_response,
+                aimg_generation=aimg_generation,
+            )
         return model_response
     except Exception as e:
         ## Map to OpenAI Exception
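With the `main.py` routing above in place, bedrock image generation is reachable through the standard entrypoint. A usage sketch mirroring the tests that follow; it assumes AWS credentials are available in the environment:

```python
import litellm

# `size` is translated to Stability's width/height params in
# get_optional_params_image_gen (see the utils.py change below).
response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="bedrock/stability.stable-diffusion-xl-v0",
    size="512x512",
    aws_region_name="us-east-1",
)
print(response.data)  # list of generated images (base64 payloads)
```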
diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py
index 54eba4cfd..59ccaacd8 100644
--- a/litellm/tests/test_image_generation.py
+++ b/litellm/tests/test_image_generation.py
@@ -119,3 +119,43 @@ async def test_async_image_generation_azure():
         pass
     else:
         pytest.fail(f"An exception occurred - {str(e)}")
+
+
+def test_image_generation_bedrock():
+    try:
+        litellm.set_verbose = True
+        response = litellm.image_generation(
+            prompt="A cute baby sea otter",
+            model="bedrock/stability.stable-diffusion-xl-v0",
+            aws_region_name="us-east-1",
+        )
+        print(f"response: {response}")
+    except litellm.RateLimitError as e:
+        pass
+    except litellm.ContentPolicyViolationError:
+        pass  # the provider occasionally raises these errors - skip when they occur
+    except Exception as e:
+        if "Your task failed as a result of our safety system." in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_aimage_generation_bedrock_with_optional_params():
+    try:
+        response = await litellm.aimage_generation(
+            prompt="A cute baby sea otter",
+            model="bedrock/stability.stable-diffusion-xl-v0",
+            size="128x128",
+        )
+        print(f"response: {response}")
+    except litellm.RateLimitError as e:
+        pass
+    except litellm.ContentPolicyViolationError:
+        pass  # the provider occasionally raises these errors - skip when they occur
+    except Exception as e:
+        if "Your task failed as a result of our safety system." in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
diff --git a/litellm/utils.py b/litellm/utils.py
index f265a0190..bf255562d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3445,21 +3445,42 @@ def get_optional_params_image_gen(
         for k, v in passed_params.items()
         if (k in default_params and v != default_params[k])
     }
-    ## raise exception if non-default value passed for non-openai/azure embedding calls
-    if custom_llm_provider != "openai" and custom_llm_provider != "azure":
-        if len(non_default_params.keys()) > 0:
-            if litellm.drop_params is True:  # drop the unsupported non-default values
-                keys = list(non_default_params.keys())
-                for k in keys:
-                    non_default_params.pop(k, None)
-                return non_default_params
-            raise UnsupportedParamsError(
-                status_code=500,
-                message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
-            )
+    optional_params = {}
 
-    final_params = {**non_default_params, **kwargs}
-    return final_params
+    ## raise exception if a non-default value is passed for an unsupported param
+    def _check_valid_arg(supported_params):
+        if len(non_default_params.keys()) > 0:
+            keys = list(non_default_params.keys())
+            for k in keys:
+                if (
+                    litellm.drop_params is True and k not in supported_params
+                ):  # drop the unsupported non-default values
+                    non_default_params.pop(k, None)
+                elif k not in supported_params:
+                    raise UnsupportedParamsError(
+                        status_code=500,
+                        message=f"Setting `{k}` is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
+                    )
+        return non_default_params
+
+    if (
+        custom_llm_provider == "openai"
+        or custom_llm_provider == "azure"
+        or custom_llm_provider in litellm.openai_compatible_providers
+    ):
+        optional_params = non_default_params
+    elif custom_llm_provider == "bedrock":
+        supported_params = ["size"]
+        _check_valid_arg(supported_params=supported_params)
+        if size is not None:
+            width, height = size.split("x")
+            optional_params["width"] = int(width)
+            optional_params["height"] = int(height)
+
+    for k in passed_params.keys():
+        if k not in default_params.keys():
+            optional_params[k] = passed_params[k]
+    return optional_params
 
 
 def get_optional_params_embeddings(
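For bedrock, the only OpenAI-style image param accepted is `size`; anything else raises `UnsupportedParamsError` unless `litellm.drop_params = True`, in which case it is silently dropped. The mapping itself is a plain string split, sketched here with a sample value:

```python
# What the bedrock branch above does with a validated `size` value.
size = "1024x1024"
width, height = size.split("x")
optional_params = {"width": int(width), "height": int(height)}
assert optional_params == {"width": 1024, "height": 1024}
```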
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 7c2b1f6bb..5838935a2 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -1477,6 +1477,48 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
+        "max_tokens": 77,
+        "max_input_tokens": 77,
+        "output_cost_per_image": 0.018,
+        "litellm_provider": "bedrock",
+        "mode": "image_generation"
+    },
+    "512-x-512/max-steps/stability.stable-diffusion-xl-v0": {
+        "max_tokens": 77,
+        "max_input_tokens": 77,
+        "output_cost_per_image": 0.036,
+        "litellm_provider": "bedrock",
+        "mode": "image_generation"
+    },
+    "max-x-max/50-steps/stability.stable-diffusion-xl-v0": {
+        "max_tokens": 77,
+        "max_input_tokens": 77,
+        "output_cost_per_image": 0.036,
+        "litellm_provider": "bedrock",
+        "mode": "image_generation"
+    },
+    "max-x-max/max-steps/stability.stable-diffusion-xl-v0": {
+        "max_tokens": 77,
+        "max_input_tokens": 77,
+        "output_cost_per_image": 0.072,
+        "litellm_provider": "bedrock",
+        "mode": "image_generation"
+    },
+    "1024-x-1024/50-steps/stability.stable-diffusion-xl-v1": {
+        "max_tokens": 77,
+        "max_input_tokens": 77,
+        "output_cost_per_image": 0.04,
+        "litellm_provider": "bedrock",
+        "mode": "image_generation"
+    },
+    "1024-x-1024/max-steps/stability.stable-diffusion-xl-v1": {
+        "max_tokens": 77,
+        "max_input_tokens": 77,
+        "output_cost_per_image": 0.08,
+        "litellm_provider": "bedrock",
+        "mode": "image_generation"
+    },
     "sagemaker/meta-textgeneration-llama-2-7b": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000,
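The new pricing keys encode output dimensions and step count alongside the model id, in the form `<width>-x-<height>/<steps>-steps/<model>`. An illustrative lookup follows; litellm's own cost tracking resolves these keys internally, and the direct file read here is only for demonstration:

```python
import json

# Illustrative lookup only, not litellm's internal cost routine.
with open("model_prices_and_context_window.json") as f:
    prices = json.load(f)

entry = prices["512-x-512/50-steps/stability.stable-diffusion-xl-v0"]
print(entry["output_cost_per_image"])  # 0.018 USD per generated image
```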