forked from phoenix/litellm-mirror
(feat) Add Bedrock Stability.ai Stable Diffusion 3 Image Generation models (#6673)
* add bedrock image gen async support
* added async support for bedrock image gen
* move image gen testing
* add AmazonStability3Config
* add AmazonStability3Config config
* update AmazonStabilityConfig
* update get_optional_params_image_gen
* use 1 helper for _get_request_body
* add transform_response_dict_to_openai_response for stability3
* test sd3-large-v1:0
* unit testing for bedrock image gen
* fix load_vertex_ai_credentials
* fix test_aimage_generation_vertex_ai
* add stability.sd3-large-v1:0 to model cost map
* add stability.stability.sd3-large-v1:0 to docs
This commit is contained in:
parent 0871c33a24
commit 979dfe8ab2
14 changed files with 528 additions and 111 deletions
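
For context, the new models are invoked through litellm's existing image generation entry point. A minimal usage sketch based on the integration test added in this commit (it assumes AWS credentials and a Bedrock-enabled region are configured in the environment):

import litellm

# Stable Diffusion 3 on Bedrock, via the same call path as the existing SDXL models.
# Model id and region match the parametrized test added in this commit.
response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="bedrock/stability.sd3-large-v1:0",
    aws_region_name="us-west-2",
)
print(response.data[0].b64_json)  # SD3 images come back base64-encoded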
@@ -1082,5 +1082,6 @@ print(f"response: {response}")
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Stable Diffusion 3 - v0 | `embedding(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` |
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |

@@ -988,6 +988,7 @@ from .llms.bedrock.common_utils import (
    AmazonBedrockGlobalConfig,
)
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
    AmazonTitanMultimodalEmbeddingG1Config,

@@ -1,6 +1,10 @@
import types
from typing import List, Optional

from openai.types.image import Image

from litellm.types.utils import ImageResponse


class AmazonStabilityConfig:
    """
@@ -67,3 +71,34 @@ class AmazonStabilityConfig:
            )
            and v is not None
        }

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        return ["size"]

    @classmethod
    def map_openai_params(
        cls,
        non_default_params: dict,
        optional_params: dict,
    ):
        _size = non_default_params.get("size")
        if _size is not None:
            width, height = _size.split("x")
            optional_params["width"] = int(width)
            optional_params["height"] = int(height)

        return optional_params

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        image_list: List[Image] = []
        for artifact in response_dict["artifacts"]:
            _image = Image(b64_json=artifact["base64"])
            image_list.append(_image)

        model_response.data = image_list

        return model_response

@@ -0,0 +1,94 @@
import types
from typing import List, Optional

from openai.types.image import Image

from litellm.types.llms.bedrock import (
    AmazonStability3TextToImageRequest,
    AmazonStability3TextToImageResponse,
)
from litellm.types.utils import ImageResponse


class AmazonStability3Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0

    Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
    """

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """
        No additional OpenAI params are mapped for stability 3
        """
        return []

    @classmethod
    def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
        """
        Returns True if the model is a Stability 3 model

        Stability 3 models follow this pattern:
            sd3-large
            sd3-large-turbo
            sd3-medium
            sd3.5-large
            sd3.5-large-turbo
        """
        if model and ("sd3" in model or "sd3.5" in model):
            return True
        return False

    @classmethod
    def transform_request_body(
        cls, prompt: str, optional_params: dict
    ) -> AmazonStability3TextToImageRequest:
        """
        Transform the request body for the Stability 3 models
        """
        data = AmazonStability3TextToImageRequest(prompt=prompt, **optional_params)
        return data

    @classmethod
    def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
        """
        Map the OpenAI params to the Bedrock params

        No OpenAI params are mapped for Stability 3, so directly return the optional_params
        """
        return optional_params

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """
        Transform the response dict to the OpenAI response
        """

        stability_3_response = AmazonStability3TextToImageResponse(**response_dict)
        openai_images: List[Image] = []
        for _img in stability_3_response.get("images", []):
            openai_images.append(Image(b64_json=_img))

        model_response.data = openai_images
        return model_response

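A small round-trip sketch of how this config class is exercised (it mirrors the unit tests added later in this commit; the response payload shown is illustrative):

from litellm.llms.bedrock.image.amazon_stability3_transformation import (
    AmazonStability3Config,
)
from litellm.types.utils import ImageResponse

# Request side: the prompt plus any passthrough SD3 params become the Bedrock request body.
body = AmazonStability3Config.transform_request_body(
    prompt="A beautiful sunset", optional_params={"aspect_ratio": "16:9"}
)
# body == {"prompt": "A beautiful sunset", "aspect_ratio": "16:9"}

# Response side: SD3 returns a list of base64 strings under "images".
resp = AmazonStability3Config.transform_response_dict_to_openai_response(
    model_response=ImageResponse(),
    response_dict={"images": ["<base64-image>"]},  # illustrative payload
)
# resp.data[0].b64_json == "<base64-image>"
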
@@ -183,28 +183,9 @@ class BedrockImageGeneration(BaseAWSLLM):
            boto3_credentials_info.aws_region_name,
        )

        # transform request
        ### FORMAT IMAGE GENERATION INPUT ###
        provider = model.split(".")[0]
        inference_params = copy.deepcopy(optional_params)
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call
        data = {}
        if provider == "stability":
            prompt = prompt.replace(os.linesep, " ")
            ## LOAD CONFIG
            config = litellm.AmazonStabilityConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
        else:
            raise BedrockError(
                status_code=422, message=f"Unsupported model={model}, passed in"
            )
        data = self._get_request_body(
            model=model, prompt=prompt, optional_params=optional_params
        )

        # Make POST Request
        body = json.dumps(data).encode("utf-8")

@@ -239,6 +220,51 @@ class BedrockImageGeneration(BaseAWSLLM):
            data=data,
        )

    def _get_request_body(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
    ) -> dict:
        """
        Get the request body for the Bedrock Image Generation API

        Checks the model/provider and transforms the request body accordingly

        Returns:
            dict: The request body to use for the Bedrock Image Generation API
        """
        provider = model.split(".")[0]
        inference_params = copy.deepcopy(optional_params)
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call
        data = {}
        if provider == "stability":
            if litellm.AmazonStability3Config._is_stability_3_model(model):
                request_body = litellm.AmazonStability3Config.transform_request_body(
                    prompt=prompt, optional_params=optional_params
                )
                return dict(request_body)
            else:
                prompt = prompt.replace(os.linesep, " ")
                ## LOAD CONFIG
                config = litellm.AmazonStabilityConfig.get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                data = {
                    "text_prompts": [{"text": prompt, "weight": 1}],
                    **inference_params,
                }
        else:
            raise BedrockError(
                status_code=422, message=f"Unsupported model={model}, passed in"
            )
        return data

    def _transform_response_dict_to_openai_response(
        self,
        model_response: ImageResponse,

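The helper above routes the same prompt into two different payload shapes depending on the model id. A brief sketch, mirroring the `_get_request_body` unit tests added further down in this commit:

from litellm.llms.bedrock.image.image_handler import BedrockImageGeneration

handler = BedrockImageGeneration()

# SD3 models get the flat Stability 3 request shape.
sd3_body = handler._get_request_body(
    model="stability.sd3-large", prompt="A beautiful sunset", optional_params={}
)
# sd3_body == {"prompt": "A beautiful sunset"}

# Classic SDXL models keep the text_prompts shape, merged with config defaults.
sdxl_body = handler._get_request_body(
    model="stability.stable-diffusion-xl",
    prompt="A beautiful sunset",
    optional_params={"cfg_scale": 7},
)
# sdxl_body["text_prompts"] == [{"text": "A beautiful sunset", "weight": 1}]
# sdxl_body["cfg_scale"] == 7
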
@@ -265,11 +291,14 @@ class BedrockImageGeneration(BaseAWSLLM):
        if response_dict is None:
            raise ValueError("Error in response object format, got None")

        image_list: List[Image] = []
        for artifact in response_dict["artifacts"]:
            _image = Image(b64_json=artifact["base64"])
            image_list.append(_image)

        model_response.data = image_list
        config_class = (
            litellm.AmazonStability3Config
            if litellm.AmazonStability3Config._is_stability_3_model(model=model)
            else litellm.AmazonStabilityConfig
        )
        config_class.transform_response_dict_to_openai_response(
            model_response=model_response,
            response_dict=response_dict,
        )

        return model_response

@@ -1,73 +0,0 @@
import copy
import os
import types
from typing import Any, Dict, List, Optional, TypedDict, Union

import litellm


class AmazonStability1Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0

    Supported Params for the Amazon / Stable Diffusion models:

    - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)

    - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)

    - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.

    - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
        Engine-specific dimension validation:

        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536

    - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
        Engine-specific dimension validation:

        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536
    """

    cfg_scale: Optional[int] = None
    seed: Optional[float] = None
    steps: Optional[List[str]] = None
    width: Optional[int] = None
    height: Optional[int] = None

    def __init__(
        self,
        cfg_scale: Optional[int] = None,
        seed: Optional[float] = None,
        steps: Optional[List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

@@ -4448,6 +4448,7 @@ def image_generation(  # noqa: PLR0915
        k: v for k, v in kwargs.items() if k not in default_params
    }  # model-specific params - pass them straight to the model/provider
    optional_params = get_optional_params_image_gen(
        model=model,
        n=n,
        quality=quality,
        response_format=response_format,

@@ -4540,7 +4541,7 @@ def image_generation(  # noqa: PLR0915
    elif custom_llm_provider == "bedrock":
        if model is None:
            raise Exception("Model needs to be set for bedrock")
        model_response = bedrock_image_generation.image_generation(
        model_response = bedrock_image_generation.image_generation(  # type: ignore
            model=model,
            prompt=prompt,
            timeout=timeout,

@@ -5611,6 +5611,13 @@
        "litellm_provider": "bedrock",
        "mode": "image_generation"
    },
    "stability.stability.sd3-large-v1:0": {
        "max_tokens": 77,
        "max_input_tokens": 77,
        "output_cost_per_image": 0.08,
        "litellm_provider": "bedrock",
        "mode": "image_generation"
    },
    "sagemaker/meta-textgeneration-llama-2-7b": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,

@@ -275,3 +275,32 @@ AmazonEmbeddingRequest = Union[
    AmazonTitanV2EmbeddingRequest,
    AmazonTitanG1EmbeddingRequest,
]


class AmazonStability3TextToImageRequest(TypedDict, total=False):
    """
    Request for Amazon Stability 3 Text to Image API

    Ref here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-3-text-image.html
    """

    prompt: str
    aspect_ratio: Literal[
        "16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"
    ]
    mode: Literal["image-to-image", "text-to-image"]
    output_format: Literal["JPEG", "PNG"]
    seed: int
    negative_prompt: str


class AmazonStability3TextToImageResponse(TypedDict, total=False):
    """
    Response for Amazon Stability 3 Text to Image API

    Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-3-text-image.html
    """

    images: List[str]
    seeds: List[str]
    finish_reasons: List[str]

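Both types are `total=False` TypedDicts, so callers only supply the fields they need. An illustrative request instance (field values are made up):

from litellm.types.llms.bedrock import AmazonStability3TextToImageRequest

# Only the fields you pass end up in the request body sent to Bedrock.
request = AmazonStability3TextToImageRequest(
    prompt="A lighthouse at dusk",  # illustrative prompt
    aspect_ratio="16:9",
    output_format="PNG",
    seed=42,
)
# request is a plain dict:
# {"prompt": "A lighthouse at dusk", "aspect_ratio": "16:9", "output_format": "PNG", "seed": 42}
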
@@ -2174,6 +2174,7 @@ def get_optional_params_transcription(


def get_optional_params_image_gen(
    model: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[str] = None,
    response_format: Optional[str] = None,

@@ -2186,6 +2187,7 @@ def get_optional_params_image_gen(
):
    # retrieve all parameters passed to the function
    passed_params = locals()
    model = passed_params.pop("model", None)
    custom_llm_provider = passed_params.pop("custom_llm_provider")
    additional_drop_params = passed_params.pop("additional_drop_params", None)
    special_params = passed_params.pop("kwargs")

@@ -2232,7 +2234,7 @@ def get_optional_params_image_gen(
        elif k not in supported_params:
            raise UnsupportedParamsError(
                status_code=500,
                message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
                message=f"Setting `{k}` is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
            )
    return non_default_params

@@ -2243,12 +2245,17 @@ def get_optional_params_image_gen(
    ):
        optional_params = non_default_params
    elif custom_llm_provider == "bedrock":
        supported_params = ["size"]
        # use stability3 config class if model is a stability3 model
        config_class = (
            litellm.AmazonStability3Config
            if litellm.AmazonStability3Config._is_stability_3_model(model=model)
            else litellm.AmazonStabilityConfig
        )
        supported_params = config_class.get_supported_openai_params(model=model)
        _check_valid_arg(supported_params=supported_params)
        if size is not None:
            width, height = size.split("x")
            optional_params["width"] = int(width)
            optional_params["height"] = int(height)
        optional_params = config_class.map_openai_params(
            non_default_params=non_default_params, optional_params={}
        )
    elif custom_llm_provider == "vertex_ai":
        supported_params = ["n"]
        """

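For the non-SD3 path, AmazonStabilityConfig is the selected config class and its only mapped OpenAI param is `size`, which it splits into Bedrock's `width`/`height`. A small sketch matching the unit test added in this commit:

from litellm.llms.bedrock.image.amazon_stability1_transformation import (
    AmazonStabilityConfig,
)

mapped = AmazonStabilityConfig.map_openai_params(
    non_default_params={"size": "512x512"},
    optional_params={"cfg_scale": 7},
)
# mapped == {"cfg_scale": 7, "width": 512, "height": 512}
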
@@ -5611,6 +5611,13 @@
        "litellm_provider": "bedrock",
        "mode": "image_generation"
    },
    "stability.stability.sd3-large-v1:0": {
        "max_tokens": 77,
        "max_input_tokens": 77,
        "output_cost_per_image": 0.08,
        "litellm_provider": "bedrock",
        "mode": "image_generation"
    },
    "sagemaker/meta-textgeneration-llama-2-7b": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,

tests/image_gen_tests/test_bedrock_image_gen_unit_tests.py (new file, 187 lines)

@@ -0,0 +1,187 @@
import logging
import os
import sys
import traceback

from dotenv import load_dotenv
from openai.types.image import Image

logging.basicConfig(level=logging.DEBUG)
load_dotenv()
import asyncio
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest

import litellm
from litellm.llms.bedrock.image.amazon_stability3_transformation import (
    AmazonStability3Config,
)
from litellm.llms.bedrock.image.amazon_stability1_transformation import (
    AmazonStabilityConfig,
)
from litellm.types.llms.bedrock import (
    AmazonStability3TextToImageRequest,
    AmazonStability3TextToImageResponse,
)
from litellm.types.utils import ImageResponse
from unittest.mock import MagicMock, patch
from litellm.llms.bedrock.image.image_handler import (
    BedrockImageGeneration,
    BedrockImagePreparedRequest,
)


@pytest.mark.parametrize(
    "model,expected",
    [
        ("sd3-large", True),
        ("sd3-large-turbo", True),
        ("sd3-medium", True),
        ("sd3.5-large", True),
        ("sd3.5-large-turbo", True),
        ("gpt-4", False),
        (None, False),
        ("other-model", False),
    ],
)
def test_is_stability_3_model(model, expected):
    result = AmazonStability3Config._is_stability_3_model(model)
    assert result == expected


def test_transform_request_body():
    prompt = "A beautiful sunset"
    optional_params = {"size": "1024x1024"}

    result = AmazonStability3Config.transform_request_body(prompt, optional_params)

    assert result["prompt"] == prompt
    assert result["size"] == "1024x1024"


def test_map_openai_params():
    non_default_params = {"n": 2, "size": "1024x1024"}
    optional_params = {"cfg_scale": 7}

    result = AmazonStability3Config.map_openai_params(
        non_default_params, optional_params
    )

    assert result == optional_params
    assert "n" not in result  # OpenAI params should not be included


def test_transform_response_dict_to_openai_response():
    # Create a mock response
    response_dict = {"images": ["base64_encoded_image_1", "base64_encoded_image_2"]}
    model_response = ImageResponse()

    result = AmazonStability3Config.transform_response_dict_to_openai_response(
        model_response, response_dict
    )

    assert isinstance(result, ImageResponse)
    assert len(result.data) == 2
    assert all(hasattr(img, "b64_json") for img in result.data)
    assert [img.b64_json for img in result.data] == response_dict["images"]


def test_amazon_stability_get_supported_openai_params():
    result = AmazonStabilityConfig.get_supported_openai_params()
    assert result == ["size"]


def test_amazon_stability_map_openai_params():
    # Test with size parameter
    non_default_params = {"size": "512x512"}
    optional_params = {"cfg_scale": 7}

    result = AmazonStabilityConfig.map_openai_params(
        non_default_params, optional_params
    )

    assert result["width"] == 512
    assert result["height"] == 512
    assert result["cfg_scale"] == 7


def test_amazon_stability_transform_response():
    # Create a mock response
    response_dict = {
        "artifacts": [
            {"base64": "base64_encoded_image_1"},
            {"base64": "base64_encoded_image_2"},
        ]
    }
    model_response = ImageResponse()

    result = AmazonStabilityConfig.transform_response_dict_to_openai_response(
        model_response, response_dict
    )

    assert isinstance(result, ImageResponse)
    assert len(result.data) == 2
    assert all(hasattr(img, "b64_json") for img in result.data)
    assert [img.b64_json for img in result.data] == [
        "base64_encoded_image_1",
        "base64_encoded_image_2",
    ]


def test_get_request_body_stability3():
    handler = BedrockImageGeneration()
    prompt = "A beautiful sunset"
    optional_params = {}
    model = "stability.sd3-large"

    result = handler._get_request_body(
        model=model, prompt=prompt, optional_params=optional_params
    )

    assert result["prompt"] == prompt


def test_get_request_body_stability():
    handler = BedrockImageGeneration()
    prompt = "A beautiful sunset"
    optional_params = {"cfg_scale": 7}
    model = "stability.stable-diffusion-xl"

    result = handler._get_request_body(
        model=model, prompt=prompt, optional_params=optional_params
    )

    assert result["text_prompts"][0]["text"] == prompt
    assert result["text_prompts"][0]["weight"] == 1
    assert result["cfg_scale"] == 7


def test_transform_response_dict_to_openai_response_stability3():
    handler = BedrockImageGeneration()
    model_response = ImageResponse()
    model = "stability.sd3-large"
    logging_obj = MagicMock()
    prompt = "A beautiful sunset"

    # Mock response for Stability AI SD3
    mock_response = MagicMock()
    mock_response.text = '{"images": ["base64_image_1", "base64_image_2"]}'
    mock_response.json.return_value = {"images": ["base64_image_1", "base64_image_2"]}

    result = handler._transform_response_dict_to_openai_response(
        model_response=model_response,
        model=model,
        logging_obj=logging_obj,
        prompt=prompt,
        response=mock_response,
        data={},
    )

    assert isinstance(result, ImageResponse)
    assert len(result.data) == 2
    assert all(hasattr(img, "b64_json") for img in result.data)
    assert [img.b64_json for img in result.data] == ["base64_image_1", "base64_image_2"]

@@ -20,6 +20,81 @@ sys.path.insert(
import pytest

import litellm
import json
import tempfile


def get_vertex_ai_creds_json() -> dict:
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"
    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, create an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    return service_account_key_data


def load_vertex_ai_credentials():
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, create an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary files
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


def test_image_generation_openai():

@@ -163,12 +238,17 @@ async def test_async_image_generation_azure():
        pytest.fail(f"An exception occurred - {str(e)}")


def test_image_generation_bedrock():
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    ["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"],
)
def test_image_generation_bedrock(model):
    try:
        litellm.set_verbose = True
        response = litellm.image_generation(
            prompt="A cute baby sea otter",
            model="bedrock/stability.stable-diffusion-xl-v1",
            model=model,
            aws_region_name="us-west-2",
        )

@@ -213,7 +293,6 @@ from openai.types.image import Image
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_aimage_generation_vertex_ai(sync_mode):
    from test_amazing_vertex_completion import load_vertex_ai_credentials

    litellm.set_verbose = True

tests/image_gen_tests/vertex_key.json (new file, 13 lines)

@@ -0,0 +1,13 @@
{
    "type": "service_account",
    "project_id": "adroit-crow-413218",
    "private_key_id": "",
    "private_key": "",
    "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
    "client_id": "104886546564708740969",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
    "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
    "universe_domain": "googleapis.com"
}