diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index afd1fee39..579353d65 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -1082,5 +1082,6 @@ print(f"response: {response}") | Model Name | Function Call | |----------------------|---------------------------------------------| +| Stable Diffusion 3 - v0 | `embedding(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` | | Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` | | Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` | \ No newline at end of file diff --git a/litellm/__init__.py b/litellm/__init__.py index 5872c4a2f..b739afb93 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -988,6 +988,7 @@ from .llms.bedrock.common_utils import ( AmazonBedrockGlobalConfig, ) from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig +from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config from .llms.bedrock.embed.amazon_titan_multimodal_transformation import ( AmazonTitanMultimodalEmbeddingG1Config, diff --git a/litellm/llms/bedrock/image/amazon_stability1_transformation.py b/litellm/llms/bedrock/image/amazon_stability1_transformation.py index 83cccb947..880881e97 100644 --- a/litellm/llms/bedrock/image/amazon_stability1_transformation.py +++ b/litellm/llms/bedrock/image/amazon_stability1_transformation.py @@ -1,6 +1,10 @@ import types from typing import List, Optional +from openai.types.image import Image + +from litellm.types.utils import ImageResponse + class AmazonStabilityConfig: """ @@ -67,3 +71,34 @@ class AmazonStabilityConfig: ) and v is not None } + + @classmethod + def get_supported_openai_params(cls, model: Optional[str] = None) -> List: + return ["size"] + + @classmethod + def map_openai_params( + cls, + non_default_params: dict, + optional_params: dict, + ): + _size = non_default_params.get("size") + if _size is not None: + width, height = _size.split("x") + optional_params["width"] = int(width) + optional_params["height"] = int(height) + + return optional_params + + @classmethod + def transform_response_dict_to_openai_response( + cls, model_response: ImageResponse, response_dict: dict + ) -> ImageResponse: + image_list: List[Image] = [] + for artifact in response_dict["artifacts"]: + _image = Image(b64_json=artifact["base64"]) + image_list.append(_image) + + model_response.data = image_list + + return model_response diff --git a/litellm/llms/bedrock/image/amazon_stability3_transformation.py b/litellm/llms/bedrock/image/amazon_stability3_transformation.py new file mode 100644 index 000000000..784e86b04 --- /dev/null +++ b/litellm/llms/bedrock/image/amazon_stability3_transformation.py @@ -0,0 +1,94 @@ +import types +from typing import List, Optional + +from openai.types.image import Image + +from litellm.types.llms.bedrock import ( + AmazonStability3TextToImageRequest, + AmazonStability3TextToImageResponse, +) +from litellm.types.utils import ImageResponse + + +class AmazonStability3Config: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0 + + Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post + """ + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + @classmethod + def get_supported_openai_params(cls, model: Optional[str] = None) -> List: + """ + No additional OpenAI params are mapped for stability 3 + """ + return [] + + @classmethod + def _is_stability_3_model(cls, model: Optional[str] = None) -> bool: + """ + Returns True if the model is a Stability 3 model + + Stability 3 models follow this pattern: + sd3-large + sd3-large-turbo + sd3-medium + sd3.5-large + sd3.5-large-turbo + """ + if model and ("sd3" in model or "sd3.5" in model): + return True + return False + + @classmethod + def transform_request_body( + cls, prompt: str, optional_params: dict + ) -> AmazonStability3TextToImageRequest: + """ + Transform the request body for the Stability 3 models + """ + data = AmazonStability3TextToImageRequest(prompt=prompt, **optional_params) + return data + + @classmethod + def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict: + """ + Map the OpenAI params to the Bedrock params + + No OpenAI params are mapped for Stability 3, so directly return the optional_params + """ + return optional_params + + @classmethod + def transform_response_dict_to_openai_response( + cls, model_response: ImageResponse, response_dict: dict + ) -> ImageResponse: + """ + Transform the response dict to the OpenAI response + """ + + stability_3_response = AmazonStability3TextToImageResponse(**response_dict) + openai_images: List[Image] = [] + for _img in stability_3_response.get("images", []): + openai_images.append(Image(b64_json=_img)) + + model_response.data = openai_images + return model_response diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py index edf852fd3..31af2910f 100644 --- a/litellm/llms/bedrock/image/image_handler.py +++ b/litellm/llms/bedrock/image/image_handler.py @@ -183,28 +183,9 @@ class BedrockImageGeneration(BaseAWSLLM): boto3_credentials_info.aws_region_name, ) - # transform request - ### FORMAT IMAGE GENERATION INPUT ### - provider = model.split(".")[0] - inference_params = copy.deepcopy(optional_params) - inference_params.pop( - "user", None - ) # make sure user is not passed in for bedrock call - data = {} - if provider == "stability": - prompt = prompt.replace(os.linesep, " ") - ## LOAD CONFIG - config = litellm.AmazonStabilityConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params} - else: - raise BedrockError( - status_code=422, message=f"Unsupported model={model}, passed in" - ) + data = self._get_request_body( + model=model, prompt=prompt, optional_params=optional_params + ) # Make POST Request body = json.dumps(data).encode("utf-8") @@ -239,6 +220,51 @@ class BedrockImageGeneration(BaseAWSLLM): data=data, ) + def _get_request_body( + self, + model: str, + prompt: str, + optional_params: dict, + ) -> dict: + """ + Get the request body for the Bedrock Image Generation API + + Checks the model/provider and transforms the request body accordingly + + Returns: + dict: The request body to use for the Bedrock Image Generation API + """ + provider = model.split(".")[0] + inference_params = copy.deepcopy(optional_params) + inference_params.pop( + "user", None + ) # make sure user is not passed in for bedrock call + data = {} + if provider == "stability": + if litellm.AmazonStability3Config._is_stability_3_model(model): + request_body = litellm.AmazonStability3Config.transform_request_body( + prompt=prompt, optional_params=optional_params + ) + return dict(request_body) + else: + prompt = prompt.replace(os.linesep, " ") + ## LOAD CONFIG + config = litellm.AmazonStabilityConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + data = { + "text_prompts": [{"text": prompt, "weight": 1}], + **inference_params, + } + else: + raise BedrockError( + status_code=422, message=f"Unsupported model={model}, passed in" + ) + return data + def _transform_response_dict_to_openai_response( self, model_response: ImageResponse, @@ -265,11 +291,14 @@ class BedrockImageGeneration(BaseAWSLLM): if response_dict is None: raise ValueError("Error in response object format, got None") - image_list: List[Image] = [] - for artifact in response_dict["artifacts"]: - _image = Image(b64_json=artifact["base64"]) - image_list.append(_image) - - model_response.data = image_list + config_class = ( + litellm.AmazonStability3Config + if litellm.AmazonStability3Config._is_stability_3_model(model=model) + else litellm.AmazonStabilityConfig + ) + config_class.transform_response_dict_to_openai_response( + model_response=model_response, + response_dict=response_dict, + ) return model_response diff --git a/litellm/llms/bedrock/image/stability_stable_diffusion1_transformation.py b/litellm/llms/bedrock/image/stability_stable_diffusion1_transformation.py deleted file mode 100644 index a83b26226..000000000 --- a/litellm/llms/bedrock/image/stability_stable_diffusion1_transformation.py +++ /dev/null @@ -1,73 +0,0 @@ -import copy -import os -import types -from typing import Any, Dict, List, Optional, TypedDict, Union - -import litellm - - -class AmazonStability1Config: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0 - - Supported Params for the Amazon / Stable Diffusion models: - - - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt) - - - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed) - - - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run. - - - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64. - Engine-specific dimension validation: - - - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. - - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 - - SDXL v1.0: same as SDXL v0.9 - - SD v1.6: must be between 320x320 and 1536x1536 - - - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64. - Engine-specific dimension validation: - - - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. - - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 - - SDXL v1.0: same as SDXL v0.9 - - SD v1.6: must be between 320x320 and 1536x1536 - """ - - cfg_scale: Optional[int] = None - seed: Optional[float] = None - steps: Optional[List[str]] = None - width: Optional[int] = None - height: Optional[int] = None - - def __init__( - self, - cfg_scale: Optional[int] = None, - seed: Optional[float] = None, - steps: Optional[List[str]] = None, - width: Optional[int] = None, - height: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } diff --git a/litellm/main.py b/litellm/main.py index 5be596e94..afb46c698 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -4448,6 +4448,7 @@ def image_generation( # noqa: PLR0915 k: v for k, v in kwargs.items() if k not in default_params } # model-specific params - pass them straight to the model/provider optional_params = get_optional_params_image_gen( + model=model, n=n, quality=quality, response_format=response_format, @@ -4540,7 +4541,7 @@ def image_generation( # noqa: PLR0915 elif custom_llm_provider == "bedrock": if model is None: raise Exception("Model needs to be set for bedrock") - model_response = bedrock_image_generation.image_generation( + model_response = bedrock_image_generation.image_generation( # type: ignore model=model, prompt=prompt, timeout=timeout, diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index a9c65b2c9..6e57fd4ed 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -5611,6 +5611,13 @@ "litellm_provider": "bedrock", "mode": "image_generation" }, + "stability.stability.sd3-large-v1:0": { + "max_tokens": 77, + "max_input_tokens": 77, + "output_cost_per_image": 0.08, + "litellm_provider": "bedrock", + "mode": "image_generation" + }, "sagemaker/meta-textgeneration-llama-2-7b": { "max_tokens": 4096, "max_input_tokens": 4096, diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 737aac3c3..c80b16f6e 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -275,3 +275,32 @@ AmazonEmbeddingRequest = Union[ AmazonTitanV2EmbeddingRequest, AmazonTitanG1EmbeddingRequest, ] + + +class AmazonStability3TextToImageRequest(TypedDict, total=False): + """ + Request for Amazon Stability 3 Text to Image API + + Ref here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-3-text-image.html + """ + + prompt: str + aspect_ratio: Literal[ + "16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21" + ] + mode: Literal["image-to-image", "text-to-image"] + output_format: Literal["JPEG", "PNG"] + seed: int + negative_prompt: str + + +class AmazonStability3TextToImageResponse(TypedDict, total=False): + """ + Response for Amazon Stability 3 Text to Image API + + Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-3-text-image.html + """ + + images: List[str] + seeds: List[str] + finish_reasons: List[str] diff --git a/litellm/utils.py b/litellm/utils.py index e4e84398f..d07d86f7d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2174,6 +2174,7 @@ def get_optional_params_transcription( def get_optional_params_image_gen( + model: Optional[str] = None, n: Optional[int] = None, quality: Optional[str] = None, response_format: Optional[str] = None, @@ -2186,6 +2187,7 @@ def get_optional_params_image_gen( ): # retrieve all parameters passed to the function passed_params = locals() + model = passed_params.pop("model", None) custom_llm_provider = passed_params.pop("custom_llm_provider") additional_drop_params = passed_params.pop("additional_drop_params", None) special_params = passed_params.pop("kwargs") @@ -2232,7 +2234,7 @@ def get_optional_params_image_gen( elif k not in supported_params: raise UnsupportedParamsError( status_code=500, - message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", + message=f"Setting `{k}` is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", ) return non_default_params @@ -2243,12 +2245,17 @@ def get_optional_params_image_gen( ): optional_params = non_default_params elif custom_llm_provider == "bedrock": - supported_params = ["size"] + # use stability3 config class if model is a stability3 model + config_class = ( + litellm.AmazonStability3Config + if litellm.AmazonStability3Config._is_stability_3_model(model=model) + else litellm.AmazonStabilityConfig + ) + supported_params = config_class.get_supported_openai_params(model=model) _check_valid_arg(supported_params=supported_params) - if size is not None: - width, height = size.split("x") - optional_params["width"] = int(width) - optional_params["height"] = int(height) + optional_params = config_class.map_openai_params( + non_default_params=non_default_params, optional_params={} + ) elif custom_llm_provider == "vertex_ai": supported_params = ["n"] """ diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index a9c65b2c9..6e57fd4ed 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -5611,6 +5611,13 @@ "litellm_provider": "bedrock", "mode": "image_generation" }, + "stability.stability.sd3-large-v1:0": { + "max_tokens": 77, + "max_input_tokens": 77, + "output_cost_per_image": 0.08, + "litellm_provider": "bedrock", + "mode": "image_generation" + }, "sagemaker/meta-textgeneration-llama-2-7b": { "max_tokens": 4096, "max_input_tokens": 4096, diff --git a/tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py b/tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py new file mode 100644 index 000000000..e04eb2a1a --- /dev/null +++ b/tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py @@ -0,0 +1,187 @@ +import logging +import os +import sys +import traceback + +from dotenv import load_dotenv +from openai.types.image import Image + +logging.basicConfig(level=logging.DEBUG) +load_dotenv() +import asyncio +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest + +import litellm +from litellm.llms.bedrock.image.amazon_stability3_transformation import ( + AmazonStability3Config, +) +from litellm.llms.bedrock.image.amazon_stability1_transformation import ( + AmazonStabilityConfig, +) +from litellm.types.llms.bedrock import ( + AmazonStability3TextToImageRequest, + AmazonStability3TextToImageResponse, +) +from litellm.types.utils import ImageResponse +from unittest.mock import MagicMock, patch +from litellm.llms.bedrock.image.image_handler import ( + BedrockImageGeneration, + BedrockImagePreparedRequest, +) + + +@pytest.mark.parametrize( + "model,expected", + [ + ("sd3-large", True), + ("sd3-large-turbo", True), + ("sd3-medium", True), + ("sd3.5-large", True), + ("sd3.5-large-turbo", True), + ("gpt-4", False), + (None, False), + ("other-model", False), + ], +) +def test_is_stability_3_model(model, expected): + result = AmazonStability3Config._is_stability_3_model(model) + assert result == expected + + +def test_transform_request_body(): + prompt = "A beautiful sunset" + optional_params = {"size": "1024x1024"} + + result = AmazonStability3Config.transform_request_body(prompt, optional_params) + + assert result["prompt"] == prompt + assert result["size"] == "1024x1024" + + +def test_map_openai_params(): + non_default_params = {"n": 2, "size": "1024x1024"} + optional_params = {"cfg_scale": 7} + + result = AmazonStability3Config.map_openai_params( + non_default_params, optional_params + ) + + assert result == optional_params + assert "n" not in result # OpenAI params should not be included + + +def test_transform_response_dict_to_openai_response(): + # Create a mock response + response_dict = {"images": ["base64_encoded_image_1", "base64_encoded_image_2"]} + model_response = ImageResponse() + + result = AmazonStability3Config.transform_response_dict_to_openai_response( + model_response, response_dict + ) + + assert isinstance(result, ImageResponse) + assert len(result.data) == 2 + assert all(hasattr(img, "b64_json") for img in result.data) + assert [img.b64_json for img in result.data] == response_dict["images"] + + +def test_amazon_stability_get_supported_openai_params(): + result = AmazonStabilityConfig.get_supported_openai_params() + assert result == ["size"] + + +def test_amazon_stability_map_openai_params(): + # Test with size parameter + non_default_params = {"size": "512x512"} + optional_params = {"cfg_scale": 7} + + result = AmazonStabilityConfig.map_openai_params( + non_default_params, optional_params + ) + + assert result["width"] == 512 + assert result["height"] == 512 + assert result["cfg_scale"] == 7 + + +def test_amazon_stability_transform_response(): + # Create a mock response + response_dict = { + "artifacts": [ + {"base64": "base64_encoded_image_1"}, + {"base64": "base64_encoded_image_2"}, + ] + } + model_response = ImageResponse() + + result = AmazonStabilityConfig.transform_response_dict_to_openai_response( + model_response, response_dict + ) + + assert isinstance(result, ImageResponse) + assert len(result.data) == 2 + assert all(hasattr(img, "b64_json") for img in result.data) + assert [img.b64_json for img in result.data] == [ + "base64_encoded_image_1", + "base64_encoded_image_2", + ] + + +def test_get_request_body_stability3(): + handler = BedrockImageGeneration() + prompt = "A beautiful sunset" + optional_params = {} + model = "stability.sd3-large" + + result = handler._get_request_body( + model=model, prompt=prompt, optional_params=optional_params + ) + + assert result["prompt"] == prompt + + +def test_get_request_body_stability(): + handler = BedrockImageGeneration() + prompt = "A beautiful sunset" + optional_params = {"cfg_scale": 7} + model = "stability.stable-diffusion-xl" + + result = handler._get_request_body( + model=model, prompt=prompt, optional_params=optional_params + ) + + assert result["text_prompts"][0]["text"] == prompt + assert result["text_prompts"][0]["weight"] == 1 + assert result["cfg_scale"] == 7 + + +def test_transform_response_dict_to_openai_response_stability3(): + handler = BedrockImageGeneration() + model_response = ImageResponse() + model = "stability.sd3-large" + logging_obj = MagicMock() + prompt = "A beautiful sunset" + + # Mock response for Stability AI SD3 + mock_response = MagicMock() + mock_response.text = '{"images": ["base64_image_1", "base64_image_2"]}' + mock_response.json.return_value = {"images": ["base64_image_1", "base64_image_2"]} + + result = handler._transform_response_dict_to_openai_response( + model_response=model_response, + model=model, + logging_obj=logging_obj, + prompt=prompt, + response=mock_response, + data={}, + ) + + assert isinstance(result, ImageResponse) + assert len(result.data) == 2 + assert all(hasattr(img, "b64_json") for img in result.data) + assert [img.b64_json for img in result.data] == ["base64_image_1", "base64_image_2"] diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py index 85f619f2f..cf46f90bb 100644 --- a/tests/image_gen_tests/test_image_generation.py +++ b/tests/image_gen_tests/test_image_generation.py @@ -20,6 +20,81 @@ sys.path.insert( import pytest import litellm +import json +import tempfile + + +def get_vertex_ai_creds_json() -> dict: + # Define the path to the vertex_key.json file + print("loading vertex ai credentials") + filepath = os.path.dirname(os.path.abspath(__file__)) + vertex_key_path = filepath + "/vertex_key.json" + # Read the existing content of the file or create an empty dictionary + try: + with open(vertex_key_path, "r") as file: + # Read the file content + print("Read vertexai file path") + content = file.read() + + # If the file is empty or not valid JSON, create an empty dictionary + if not content or not content.strip(): + service_account_key_data = {} + else: + # Attempt to load the existing JSON content + file.seek(0) + service_account_key_data = json.load(file) + except FileNotFoundError: + # If the file doesn't exist, create an empty dictionary + service_account_key_data = {} + + # Update the service_account_key_data with environment variables + private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "") + private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "") + private_key = private_key.replace("\\n", "\n") + service_account_key_data["private_key_id"] = private_key_id + service_account_key_data["private_key"] = private_key + + return service_account_key_data + + +def load_vertex_ai_credentials(): + # Define the path to the vertex_key.json file + print("loading vertex ai credentials") + filepath = os.path.dirname(os.path.abspath(__file__)) + vertex_key_path = filepath + "/vertex_key.json" + + # Read the existing content of the file or create an empty dictionary + try: + with open(vertex_key_path, "r") as file: + # Read the file content + print("Read vertexai file path") + content = file.read() + + # If the file is empty or not valid JSON, create an empty dictionary + if not content or not content.strip(): + service_account_key_data = {} + else: + # Attempt to load the existing JSON content + file.seek(0) + service_account_key_data = json.load(file) + except FileNotFoundError: + # If the file doesn't exist, create an empty dictionary + service_account_key_data = {} + + # Update the service_account_key_data with environment variables + private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "") + private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "") + private_key = private_key.replace("\\n", "\n") + service_account_key_data["private_key_id"] = private_key_id + service_account_key_data["private_key"] = private_key + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + # Write the updated content to the temporary files + json.dump(service_account_key_data, temp_file, indent=2) + + # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) def test_image_generation_openai(): @@ -163,12 +238,17 @@ async def test_async_image_generation_azure(): pytest.fail(f"An exception occurred - {str(e)}") -def test_image_generation_bedrock(): +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model", + ["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"], +) +def test_image_generation_bedrock(model): try: litellm.set_verbose = True response = litellm.image_generation( prompt="A cute baby sea otter", - model="bedrock/stability.stable-diffusion-xl-v1", + model=model, aws_region_name="us-west-2", ) @@ -213,7 +293,6 @@ from openai.types.image import Image @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_aimage_generation_vertex_ai(sync_mode): - from test_amazing_vertex_completion import load_vertex_ai_credentials litellm.set_verbose = True diff --git a/tests/image_gen_tests/vertex_key.json b/tests/image_gen_tests/vertex_key.json new file mode 100644 index 000000000..e2fd8512b --- /dev/null +++ b/tests/image_gen_tests/vertex_key.json @@ -0,0 +1,13 @@ +{ + "type": "service_account", + "project_id": "adroit-crow-413218", + "private_key_id": "", + "private_key": "", + "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com", + "client_id": "104886546564708740969", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com", + "universe_domain": "googleapis.com" +}