# What does this test?
## This tests litellm support for the OpenAI /images/generations (image generation) endpoint

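# For reference, the basic call shape these tests exercise (a minimal sketch;
# the model name is illustrative - any configured image model works):
#
#   response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
#   print(response.data[0].url)
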
import asyncio
import json
import logging
import os
import sys
import tempfile
import traceback

from dotenv import load_dotenv
from openai.types.image import Image

logging.basicConfig(level=logging.DEBUG)
load_dotenv()

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path so the local litellm is importable

import pytest

import litellm
from base_image_generation_test import BaseImageGenTest
from litellm._logging import verbose_logger
from litellm.caching import InMemoryCache

verbose_logger.setLevel(logging.DEBUG)


def get_vertex_ai_creds_json() -> dict:
    """Load vertex_key.json (if present) and overlay key material from env vars."""
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file, or start with an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, fall back to an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                try:
                    service_account_key_data = json.load(file)
                except json.JSONDecodeError:
                    service_account_key_data = {}
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Overlay the private key material from environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    return service_account_key_data


def load_vertex_ai_credentials():
    """Write the assembled service-account JSON to a temp file and export it
    via GOOGLE_APPLICATION_CREDENTIALS."""
    print("loading vertex ai credentials")
    # Reuse get_vertex_ai_creds_json() instead of duplicating its logic
    service_account_key_data = get_vertex_ai_creds_json()

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


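# A minimal local-run sketch for the helpers above (assumes vertex_key.json or
# the VERTEX_AI_PRIVATE_KEY* env vars are populated; project/location values
# mirror the test class below):
#
#   load_vertex_ai_credentials()
#   response = litellm.image_generation(
#       prompt="A cute baby sea otter",
#       model="vertex_ai/imagegeneration@006",
#       vertex_ai_project="adroit-crow-413218",
#       vertex_ai_location="us-central1",
#   )

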
class TestVertexImageGeneration(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        # comment this out when running locally
        load_vertex_ai_credentials()

        litellm.in_memory_llm_clients_cache = InMemoryCache()
        return {
            "model": "vertex_ai/imagegeneration@006",
            "vertex_ai_project": "adroit-crow-413218",
            "vertex_ai_location": "us-central1",
            "n": 1,
        }


class TestBedrockSd3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        return {"model": "bedrock/stability.sd3-large-v1:0"}


class TestBedrockSd1(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        # Stable Diffusion v1 (XL) model, matching the class name
        return {"model": "bedrock/stability.stable-diffusion-xl-v1"}


class TestOpenAIDalle3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        return {"model": "dall-e-3"}


class TestAzureOpenAIDalle3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.set_verbose = True
        return {
            "model": "azure/dall-e-3-test",
            "api_version": "2023-09-01-preview",
            "metadata": {
                "model_info": {
                    "base_model": "dall-e-3",
                }
            },
        }


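# BaseImageGenTest (imported above) is assumed to drive litellm.image_generation /
# litellm.aimage_generation with the dict each subclass returns from
# get_base_image_generation_call_args; the subclasses above only vary the
# provider-specific arguments.

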
@pytest.mark.flaky(retries=3, delay=1)
def test_image_generation_azure_dall_e_3():
    try:
        litellm.set_verbose = True
        response = litellm.image_generation(
            prompt="A cute baby sea otter",
            model="azure/dall-e-3-test",
            api_version="2023-12-01-preview",
            api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
            api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
        )
        print(f"response: {response}")
        assert len(response.data) > 0
    except litellm.InternalServerError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # OpenAI randomly raises these errors - skip when they occur
    except Exception as e:
        if "Your task failed as a result of our safety system." in str(e):
            pass
        elif "Connection error" in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")


# asyncio.run(test_async_image_generation_openai())

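# A hypothetical sketch of that async variant (not part of the suite; the model
# is illustrative, and litellm.aimage_generation is used the same way below):
#
#   async def test_async_image_generation_openai():
#       response = await litellm.aimage_generation(
#           prompt="A cute baby sea otter", model="dall-e-3"
#       )
#       assert len(response.data) > 0

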
@pytest.mark.asyncio
async def test_aimage_generation_bedrock_with_optional_params():
    try:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="bedrock/stability.stable-diffusion-xl-v1",
            size="256x256",
        )
        print(f"response: {response}")
    except litellm.RateLimitError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # Bedrock randomly raises these errors - skip when they occur
    except Exception as e:
        if "Your task failed as a result of our safety system." in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")
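

# Local run (assumes AWS credentials in the environment):
# asyncio.run(test_aimage_generation_bedrock_with_optional_params())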