[Fix] Performance - use in memory cache when downloading images from a url (#5657)
* fix use in memory cache when getting images
* fix linting
* fix load testing
* fix load test size
* fix load test size
* trigger ci/cd again

parent cdd7cd4d69 · commit cd8d7ca915

5 changed files with 249 additions and 38 deletions
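The change moves URL-to-base64 image conversion into a new module, litellm/llms/prompt_templates/image_handling.py, and fronts both the sync and async helpers with an in-memory cache holding at most 10 images, so repeated requests for the same image URL skip the download entirely. A minimal sketch of the read-through pattern the new helpers follow; InMemoryCache, get_cache, and set_cache all appear in the diff below, while `download` is a hypothetical stand-in for the HTTP fetch plus base64 encoding:

# Sketch of the read-through caching pattern introduced by this commit.
# `download` is a hypothetical stand-in for the fetch + encode step.
from litellm.caching import InMemoryCache

cache = InMemoryCache(max_size_in_memory=10)


def cached_convert(url: str, download) -> str:
    hit = cache.get_cache(url)
    if hit:
        return hit  # later calls for the same URL never touch the network
    result = download(url)  # yields a "data:<mime>;base64,<...>" string
    cache.set_cache(url, result)
    return result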
@@ -37,6 +37,8 @@ from litellm.types.llms.openai import (
 )
 from litellm.types.utils import GenericImageParsingChunk

+from .image_handling import async_convert_url_to_base64, convert_url_to_base64
+

 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -703,44 +705,6 @@ def construct_tool_use_system_prompt(
     return tool_use_system_prompt


-def convert_url_to_base64(url):
-    import base64
-
-    client = HTTPHandler(concurrent_limit=1)
-    for _ in range(3):
-        try:
-            response = client.get(url)
-            break
-        except:
-            pass
-    if response.status_code == 200:
-        image_bytes = response.content
-        base64_image = base64.b64encode(image_bytes).decode("utf-8")
-
-        image_type = response.headers.get("Content-Type", None)
-        if image_type is not None:
-            img_type = image_type
-        else:
-            img_type = url.split(".")[-1].lower()
-            if img_type == "jpg" or img_type == "jpeg":
-                img_type = "image/jpeg"
-            elif img_type == "png":
-                img_type = "image/png"
-            elif img_type == "gif":
-                img_type = "image/gif"
-            elif img_type == "webp":
-                img_type = "image/webp"
-            else:
-                raise Exception(
-                    f"Error: Unsupported image format. Format={img_type}. Supported types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']"
-                )
-
-        return f"data:{img_type};base64,{base64_image}"
-    else:
-        raise Exception(f"Error: Unable to fetch image from URL. url={url}")
-
-
 def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
     """
     Input:
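One thing the removed version above got wrong: if all three client.get attempts raised, the loop fell through with `response` never bound, and the status check crashed with a NameError instead of a useful message. The replacement in image_handling.py (next file) returns inside the try and raises an explicit error after the loop. A minimal sketch of that retry shape, where `fetch` is a hypothetical stand-in for the HTTP call plus response processing:

# Retry shape used by the new helpers: return on the first successful attempt,
# raise a clear error once every attempt has failed. `fetch` is a hypothetical
# stand-in for client.get(url) followed by _process_image_response.
def fetch_with_retries(fetch, url: str, attempts: int = 3) -> str:
    for _ in range(attempts):
        try:
            return fetch(url)
        except Exception:
            continue
    raise Exception(
        f"Error: Unable to fetch image from URL after {attempts} attempts. url={url}"
    )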
litellm/llms/prompt_templates/image_handling.py (new file, 84 additions)
@@ -0,0 +1,84 @@
"""
Helper functions to handle images passed in messages
"""

import base64

from httpx import Response

import litellm
from litellm.caching import InMemoryCache
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)

MAX_IMGS_IN_MEMORY = 10

in_memory_cache = InMemoryCache(max_size_in_memory=MAX_IMGS_IN_MEMORY)


def _process_image_response(response: Response, url: str) -> str:
    if response.status_code != 200:
        raise Exception(
            f"Error: Unable to fetch image from URL. Status code: {response.status_code}, url={url}"
        )

    image_bytes = response.content
    base64_image = base64.b64encode(image_bytes).decode("utf-8")

    image_type = response.headers.get("Content-Type")
    if image_type is None:
        img_type = url.split(".")[-1].lower()
        _img_type = {
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "png": "image/png",
            "gif": "image/gif",
            "webp": "image/webp",
        }.get(img_type)
        if _img_type is None:
            raise Exception(
                f"Error: Unsupported image format. Format={_img_type}. Supported types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']"
            )
        img_type = _img_type
    else:
        img_type = image_type

    result = f"data:{img_type};base64,{base64_image}"
    in_memory_cache.set_cache(url, result)
    return result


async def async_convert_url_to_base64(url: str) -> str:
    cached_result = in_memory_cache.get_cache(url)
    if cached_result:
        return cached_result

    client = litellm.module_level_aclient
    for _ in range(3):
        try:
            response = await client.get(url)
            return _process_image_response(response, url)
        except:
            pass
    raise Exception(
        f"Error: Unable to fetch image from URL after 3 attempts. url={url}"
    )


def convert_url_to_base64(url: str) -> str:
    cached_result = in_memory_cache.get_cache(url)
    if cached_result:
        return cached_result

    client = litellm.module_level_client
    for _ in range(3):
        try:
            response = client.get(url)
            return _process_image_response(response, url)
        except:
            pass
    raise Exception(
        f"Error: Unable to fetch image from URL after 3 attempts. url={url}"
    )
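A quick usage sketch of the synchronous helper; the URL is the same illustrative one the load test below uses. The first call fetches and encodes the image and fills the cache; the second returns the cached data URI without a network call:

# Assumes litellm is installed and the URL is reachable; illustrative only.
from litellm.llms.prompt_templates.image_handling import convert_url_to_base64

url = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
first = convert_url_to_base64(url)   # HTTP fetch + base64 encode + cache fill
second = convert_url_to_base64(url)  # served from in_memory_cache, no fetch
assert first == second and first.startswith("data:image/")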
@@ -25,6 +25,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt

 # litellm.num_retries =3
 litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
tests/load_tests/test_vertex_load_tests.py (new file, 149 additions)
@@ -0,0 +1,149 @@
import sys
import os

sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import litellm
import pytest
import time
import json
import tempfile
from dotenv import load_dotenv


def load_vertex_ai_credentials():
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, create an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


@pytest.mark.asyncio
async def test_vertex_load():
    try:
        load_vertex_ai_credentials()
        percentage_diffs = []

        for run in range(3):
            print(f"\nRun {run + 1}:")

            # Test with text-only message
            start_time_text = await make_async_calls(message_type="text")
            print("Done with text-only message test")

            # Test with text + image message
            start_time_image = await make_async_calls(message_type="image")
            print("Done with text + image message test")

            # Compare times and calculate percentage difference
            print(f"Time with text-only message: {start_time_text}")
            print(f"Time with text + image message: {start_time_image}")

            percentage_diff = (
                (start_time_image - start_time_text) / start_time_text * 100
            )
            percentage_diffs.append(percentage_diff)
            print(f"Performance difference: {percentage_diff:.2f}%")

        print("percentage_diffs", percentage_diffs)
        # Calculate average percentage difference
        avg_percentage_diff = sum(percentage_diffs) / len(percentage_diffs)
        print(f"\nAverage performance difference: {avg_percentage_diff:.2f}%")

        # Assert that the average difference is not more than 20%
        assert (
            avg_percentage_diff < 20
        ), f"Average performance difference of {avg_percentage_diff:.2f}% exceeds 20% threshold"

    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred - {e}")


async def make_async_calls(message_type="text"):
    total_tasks = 3
    batch_size = 1
    total_time = 0

    for batch in range(3):
        tasks = [create_async_task(message_type) for _ in range(batch_size)]

        start_time = asyncio.get_event_loop().time()
        responses = await asyncio.gather(*tasks)

        for idx, response in enumerate(responses):
            print(f"Response from Task {batch * batch_size + idx + 1}: {response}")

        await asyncio.sleep(1)

        batch_time = asyncio.get_event_loop().time() - start_time
        total_time += batch_time

    return total_time


def create_async_task(message_type):
    base_url = "https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001"

    if message_type == "text":
        messages = [{"role": "user", "content": "hi"}]
    else:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
                        },
                    },
                ],
            }
        ]

    completion_args = {
        "model": "vertex_ai/gemini",
        "messages": messages,
        "max_tokens": 5,
        "temperature": 0.7,
        "timeout": 10,
        "api_base": base_url,
    }
    return asyncio.create_task(litellm.acompletion(**completion_args))
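For concreteness on the pass criterion, a worked instance of the arithmetic in test_vertex_load (the timings are hypothetical):

# If the text-only batches total 2.0s and the text+image batches 2.3s:
start_time_text = 2.0
start_time_image = 2.3
percentage_diff = (start_time_image - start_time_text) / start_time_text * 100
assert percentage_diff < 20  # 15.0% here; the real test averages this over 3 runs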
tests/load_tests/vertex_key.json (new file, 13 additions)
@@ -0,0 +1,13 @@
{
  "type": "service_account",
  "project_id": "adroit-crow-413218",
  "private_key_id": "",
  "private_key": "",
  "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
  "client_id": "104886546564708740969",
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
  "token_uri": "https://oauth2.googleapis.com/token",
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
  "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
  "universe_domain": "googleapis.com"
}
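The blank private_key_id and private_key fields in this template are intentional: load_vertex_ai_credentials() in the test above reads this file, fills those two fields from the VERTEX_AI_PRIVATE_KEY_ID and VERTEX_AI_PRIVATE_KEY environment variables, writes the result to a temporary file, and points GOOGLE_APPLICATION_CREDENTIALS at it.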