diff --git a/.circleci/config.yml b/.circleci/config.yml index 4bb232421..d2d83cd0e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -625,6 +625,48 @@ jobs: paths: - llm_translation_coverage.xml - llm_translation_coverage + image_gen_testing: + docker: + - image: cimg/python:3.11 + auth: + username: ${DOCKERHUB_USERNAME} + password: ${DOCKERHUB_PASSWORD} + working_directory: ~/project + + steps: + - checkout + - run: + name: Install Dependencies + command: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + pip install "pytest==7.3.1" + pip install "pytest-retry==1.6.3" + pip install "pytest-cov==5.0.0" + pip install "pytest-asyncio==0.21.1" + pip install "respx==0.21.1" + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/image_gen_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m + - run: + name: Rename the coverage files + command: | + mv coverage.xml image_gen_coverage.xml + mv .coverage image_gen_coverage + + # Store test results + - store_test_results: + path: test-results + - persist_to_workspace: + root: . + paths: + - image_gen_coverage.xml + - image_gen_coverage logging_testing: docker: - image: cimg/python:3.11 @@ -877,7 +919,7 @@ jobs: command: | pwd ls - python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation + python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests no_output_timeout: 120m # Store test results @@ -1114,7 +1156,7 @@ jobs: python -m venv venv . venv/bin/activate pip install coverage - coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage + coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage coverage xml - codecov/upload: file: ./coverage.xml @@ -1403,6 +1445,12 @@ workflows: only: - main - /litellm_.*/ + - image_gen_testing: + filters: + branches: + only: + - main + - /litellm_.*/ - logging_testing: filters: branches: @@ -1412,6 +1460,7 @@ workflows: - upload-coverage: requires: - llm_translation_testing + - image_gen_testing - logging_testing - litellm_router_testing - caching_unit_tests @@ -1451,6 +1500,7 @@ workflows: - load_testing - test_bad_database_url - llm_translation_testing + - image_gen_testing - logging_testing - litellm_router_testing - caching_unit_tests diff --git a/litellm/__init__.py b/litellm/__init__.py index 1951dd12f..5872c4a2f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -984,10 +984,10 @@ from .llms.bedrock.common_utils import ( AmazonAnthropicClaude3Config, AmazonCohereConfig, AmazonLlamaConfig, - AmazonStabilityConfig, AmazonMistralConfig, AmazonBedrockGlobalConfig, ) +from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config from .llms.bedrock.embed.amazon_titan_multimodal_transformation import ( AmazonTitanMultimodalEmbeddingG1Config, diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py index 70e3defc7..9f3a58a8b 100644 --- a/litellm/llms/base_aws_llm.py +++ b/litellm/llms/base_aws_llm.py @@ -1,16 +1,28 @@ import hashlib import json import os -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import httpx +from pydantic import BaseModel from litellm._logging import verbose_logger from litellm.caching.caching import DualCache, InMemoryCache -from litellm.secret_managers.main import get_secret +from litellm.secret_managers.main import get_secret, get_secret_str from .base import BaseLLM +if TYPE_CHECKING: + from botocore.credentials import Credentials +else: + Credentials = Any + + +class Boto3CredentialsInfo(BaseModel): + credentials: Credentials + aws_region_name: str + aws_bedrock_runtime_endpoint: Optional[str] + class AwsAuthError(Exception): def __init__(self, status_code, message): @@ -311,3 +323,74 @@ class BaseAWSLLM(BaseLLM): proxy_endpoint_url = endpoint_url return endpoint_url, proxy_endpoint_url + + def _get_boto_credentials_from_optional_params( + self, optional_params: dict + ) -> Boto3CredentialsInfo: + """ + Get boto3 credentials from optional params + + Args: + optional_params (dict): Optional parameters for the model call + + Returns: + Credentials: Boto3 credentials object + """ + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret_str("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, + ) + + return Boto3CredentialsInfo( + credentials=credentials, + aws_region_name=aws_region_name, + aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + ) diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 1ae74e535..332b1e2b3 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -484,73 +484,6 @@ class AmazonMistralConfig: } -class AmazonStabilityConfig: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0 - - Supported Params for the Amazon / Stable Diffusion models: - - - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt) - - - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed) - - - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run. - - - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64. - Engine-specific dimension validation: - - - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. - - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 - - SDXL v1.0: same as SDXL v0.9 - - SD v1.6: must be between 320x320 and 1536x1536 - - - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64. - Engine-specific dimension validation: - - - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. - - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 - - SDXL v1.0: same as SDXL v0.9 - - SD v1.6: must be between 320x320 and 1536x1536 - """ - - cfg_scale: Optional[int] = None - seed: Optional[float] = None - steps: Optional[List[str]] = None - width: Optional[int] = None - height: Optional[int] = None - - def __init__( - self, - cfg_scale: Optional[int] = None, - seed: Optional[float] = None, - steps: Optional[List[str]] = None, - width: Optional[int] = None, - height: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def add_custom_header(headers): """Closure to capture the headers and add them.""" diff --git a/litellm/llms/bedrock/image/amazon_stability1_transformation.py b/litellm/llms/bedrock/image/amazon_stability1_transformation.py new file mode 100644 index 000000000..83cccb947 --- /dev/null +++ b/litellm/llms/bedrock/image/amazon_stability1_transformation.py @@ -0,0 +1,69 @@ +import types +from typing import List, Optional + + +class AmazonStabilityConfig: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0 + + Supported Params for the Amazon / Stable Diffusion models: + + - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt) + + - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed) + + - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run. + + - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64. + Engine-specific dimension validation: + + - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. + - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 + - SDXL v1.0: same as SDXL v0.9 + - SD v1.6: must be between 320x320 and 1536x1536 + + - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64. + Engine-specific dimension validation: + + - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. + - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 + - SDXL v1.0: same as SDXL v0.9 + - SD v1.6: must be between 320x320 and 1536x1536 + """ + + cfg_scale: Optional[int] = None + seed: Optional[float] = None + steps: Optional[List[str]] = None + width: Optional[int] = None + height: Optional[int] = None + + def __init__( + self, + cfg_scale: Optional[int] = None, + seed: Optional[float] = None, + steps: Optional[List[str]] = None, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py new file mode 100644 index 000000000..edf852fd3 --- /dev/null +++ b/litellm/llms/bedrock/image/image_handler.py @@ -0,0 +1,275 @@ +import copy +import json +import os +from typing import TYPE_CHECKING, Any, List, Optional, Union + +import httpx +from openai.types.image import Image +from pydantic import BaseModel + +import litellm +from litellm._logging import verbose_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, +) +from litellm.types.utils import ImageResponse + +from ...base_aws_llm import BaseAWSLLM +from ..common_utils import BedrockError + +if TYPE_CHECKING: + from botocore.awsrequest import AWSPreparedRequest +else: + AWSPreparedRequest = Any + + +class BedrockImagePreparedRequest(BaseModel): + """ + Internal/Helper class for preparing the request for bedrock image generation + """ + + endpoint_url: str + prepped: AWSPreparedRequest + body: bytes + data: dict + + +class BedrockImageGeneration(BaseAWSLLM): + """ + Bedrock Image Generation handler + """ + + def image_generation( + self, + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + logging_obj: LitellmLogging, + timeout: Optional[Union[float, httpx.Timeout]], + aimg_generation: bool = False, + api_base: Optional[str] = None, + extra_headers: Optional[dict] = None, + ): + prepared_request = self._prepare_request( + model=model, + optional_params=optional_params, + api_base=api_base, + extra_headers=extra_headers, + logging_obj=logging_obj, + prompt=prompt, + ) + + if aimg_generation is True: + return self.async_image_generation( + prepared_request=prepared_request, + timeout=timeout, + model=model, + logging_obj=logging_obj, + prompt=prompt, + model_response=model_response, + ) + + client = _get_httpx_client() + try: + response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + ### FORMAT RESPONSE TO OPENAI FORMAT ### + model_response = self._transform_response_dict_to_openai_response( + model_response=model_response, + model=model, + logging_obj=logging_obj, + prompt=prompt, + response=response, + data=prepared_request.data, + ) + return model_response + + async def async_image_generation( + self, + prepared_request: BedrockImagePreparedRequest, + timeout: Optional[Union[float, httpx.Timeout]], + model: str, + logging_obj: LitellmLogging, + prompt: str, + model_response: ImageResponse, + ) -> ImageResponse: + """ + Asynchronous handler for bedrock image generation + + Awaits the response from the bedrock image generation endpoint + """ + async_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.BEDROCK, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + ### FORMAT RESPONSE TO OPENAI FORMAT ### + model_response = self._transform_response_dict_to_openai_response( + model=model, + logging_obj=logging_obj, + prompt=prompt, + response=response, + data=prepared_request.data, + model_response=model_response, + ) + return model_response + + def _prepare_request( + self, + model: str, + optional_params: dict, + api_base: Optional[str], + extra_headers: Optional[dict], + logging_obj: LitellmLogging, + prompt: str, + ) -> BedrockImagePreparedRequest: + """ + Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API + + Args: + model (str): The model to use for the image generation + optional_params (dict): The optional parameters for the image generation + api_base (Optional[str]): The base URL for the Bedrock API + extra_headers (Optional[dict]): The extra headers to include in the request + logging_obj (LitellmLogging): The logging object to use for logging + prompt (str): The prompt to use for the image generation + Returns: + BedrockImagePreparedRequest: The prepared request object + + The BedrockImagePreparedRequest contains: + endpoint_url (str): The endpoint URL for the Bedrock Image Generation API + prepped (httpx.Request): The prepared request object + body (bytes): The request body + """ + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + boto3_credentials_info = self._get_boto_credentials_from_optional_params( + optional_params + ) + + ### SET RUNTIME ENDPOINT ### + modelId = model + _, proxy_endpoint_url = self.get_runtime_endpoint( + api_base=api_base, + aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint, + aws_region_name=boto3_credentials_info.aws_region_name, + ) + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke" + sigv4 = SigV4Auth( + boto3_credentials_info.credentials, + "bedrock", + boto3_credentials_info.aws_region_name, + ) + + # transform request + ### FORMAT IMAGE GENERATION INPUT ### + provider = model.split(".")[0] + inference_params = copy.deepcopy(optional_params) + inference_params.pop( + "user", None + ) # make sure user is not passed in for bedrock call + data = {} + if provider == "stability": + prompt = prompt.replace(os.linesep, " ") + ## LOAD CONFIG + config = litellm.AmazonStabilityConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params} + else: + raise BedrockError( + status_code=422, message=f"Unsupported model={model}, passed in" + ) + + # Make POST Request + body = json.dumps(data).encode("utf-8") + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=proxy_endpoint_url, data=body, headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + prepped = request.prepare() + + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": proxy_endpoint_url, + "headers": prepped.headers, + }, + ) + return BedrockImagePreparedRequest( + endpoint_url=proxy_endpoint_url, + prepped=prepped, + body=body, + data=data, + ) + + def _transform_response_dict_to_openai_response( + self, + model_response: ImageResponse, + model: str, + logging_obj: LitellmLogging, + prompt: str, + response: httpx.Response, + data: dict, + ) -> ImageResponse: + """ + Transforms the Image Generation response from Bedrock to OpenAI format + """ + + ## LOGGING + if logging_obj is not None: + logging_obj.post_call( + input=prompt, + api_key="", + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + verbose_logger.debug("raw model_response: %s", response.text) + response_dict = response.json() + if response_dict is None: + raise ValueError("Error in response object format, got None") + + image_list: List[Image] = [] + for artifact in response_dict["artifacts"]: + _image = Image(b64_json=artifact["base64"]) + image_list.append(_image) + + model_response.data = image_list + + return model_response diff --git a/litellm/llms/bedrock/image/stability_stable_diffusion1_transformation.py b/litellm/llms/bedrock/image/stability_stable_diffusion1_transformation.py new file mode 100644 index 000000000..a83b26226 --- /dev/null +++ b/litellm/llms/bedrock/image/stability_stable_diffusion1_transformation.py @@ -0,0 +1,73 @@ +import copy +import os +import types +from typing import Any, Dict, List, Optional, TypedDict, Union + +import litellm + + +class AmazonStability1Config: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0 + + Supported Params for the Amazon / Stable Diffusion models: + + - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt) + + - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed) + + - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run. + + - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64. + Engine-specific dimension validation: + + - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. + - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 + - SDXL v1.0: same as SDXL v0.9 + - SD v1.6: must be between 320x320 and 1536x1536 + + - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64. + Engine-specific dimension validation: + + - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. + - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 + - SDXL v1.0: same as SDXL v0.9 + - SD v1.6: must be between 320x320 and 1536x1536 + """ + + cfg_scale: Optional[int] = None + seed: Optional[float] = None + steps: Optional[List[str]] = None + width: Optional[int] = None + height: Optional[int] = None + + def __init__( + self, + cfg_scale: Optional[int] = None, + seed: Optional[float] = None, + steps: Optional[List[str]] = None, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } diff --git a/litellm/llms/bedrock/image_generation.py b/litellm/llms/bedrock/image_generation.py deleted file mode 100644 index 65038d12e..000000000 --- a/litellm/llms/bedrock/image_generation.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Handles image gen calls to Bedrock's `/invoke` endpoint -""" - -import copy -import json -import os -from typing import Any, List - -from openai.types.image import Image - -import litellm -from litellm.types.utils import ImageResponse - -from .common_utils import BedrockError, init_bedrock_client - - -def image_generation( - model: str, - prompt: str, - model_response: ImageResponse, - optional_params: dict, - logging_obj: Any, - timeout=None, - aimg_generation=False, -): - """ - Bedrock Image Gen endpoint support - """ - ### BOTO3 INIT ### - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_region_name = optional_params.pop("aws_region_name", None) - aws_role_name = optional_params.pop("aws_role_name", None) - aws_session_name = optional_params.pop("aws_session_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop( - "aws_bedrock_runtime_endpoint", None - ) - aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) - - # use passed in BedrockRuntime.Client if provided, otherwise create a new one - client = init_bedrock_client( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_region_name=aws_region_name, - aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - aws_web_identity_token=aws_web_identity_token, - aws_role_name=aws_role_name, - aws_session_name=aws_session_name, - timeout=timeout, - ) - - ### FORMAT IMAGE GENERATION INPUT ### - modelId = model - provider = model.split(".")[0] - inference_params = copy.deepcopy(optional_params) - inference_params.pop( - "user", None - ) # make sure user is not passed in for bedrock call - data = {} - if provider == "stability": - prompt = prompt.replace(os.linesep, " ") - ## LOAD CONFIG - config = litellm.AmazonStabilityConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params} - else: - raise BedrockError( - status_code=422, message=f"Unsupported model={model}, passed in" - ) - - body = json.dumps(data).encode("utf-8") - ## LOGGING - request_str = f""" - response = client.invoke_model( - body={body}, # type: ignore - modelId={modelId}, - accept="application/json", - contentType="application/json", - )""" # type: ignore - logging_obj.pre_call( - input=prompt, - api_key="", # boto3 is used for init. - additional_args={ - "complete_input_dict": {"model": modelId, "texts": prompt}, - "request_str": request_str, - }, - ) - try: - response = client.invoke_model( - body=body, - modelId=modelId, - accept="application/json", - contentType="application/json", - ) - response_body = json.loads(response.get("body").read()) - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": data}, - original_response=json.dumps(response_body), - ) - except Exception as e: - raise BedrockError( - message=f"Embedding Error with model {model}: {e}", status_code=500 - ) - - ### FORMAT RESPONSE TO OPENAI FORMAT ### - if response_body is None: - raise Exception("Error in response object format") - - if model_response is None: - model_response = ImageResponse() - - image_list: List[Image] = [] - for artifact in response_body["artifacts"]: - _image = Image(b64_json=artifact["base64"]) - image_list.append(_image) - - model_response.data = image_list - return model_response diff --git a/litellm/main.py b/litellm/main.py index 8334f35d7..5be596e94 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -108,9 +108,9 @@ from .llms.azure_text import AzureTextCompletion from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion -from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from .llms.bedrock.embed.embedding import BedrockEmbedding +from .llms.bedrock.image.image_handler import BedrockImageGeneration from .llms.cohere import chat as cohere_chat from .llms.cohere import completion as cohere_completion # type: ignore from .llms.cohere.embed import handler as cohere_embed @@ -214,6 +214,7 @@ triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() bedrock_converse_chat_completion = BedrockConverseLLM() bedrock_embedding = BedrockEmbedding() +bedrock_image_generation = BedrockImageGeneration() vertex_chat_completion = VertexLLM() vertex_embedding = VertexEmbedding() vertex_multimodal_embedding = VertexMultimodalEmbedding() diff --git a/tests/local_testing/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py similarity index 100% rename from tests/local_testing/test_image_generation.py rename to tests/image_gen_tests/test_image_generation.py