From cd8d7ca9156a5fc2510db1ef0d43956d3239eccf Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 13 Sep 2024 07:23:42 -0700 Subject: [PATCH] [Fix] Performance - use in memory cache when downloading images from a url (#5657) * fix use in memory cache when getting images * fix linting * fix load testing * fix load test size * fix load test size * trigger ci/cd again --- litellm/llms/prompt_templates/factory.py | 40 +---- .../llms/prompt_templates/image_handling.py | 84 ++++++++++ litellm/tests/test_completion.py | 1 + tests/load_tests/test_vertex_load_tests.py | 149 ++++++++++++++++++ tests/load_tests/vertex_key.json | 13 ++ 5 files changed, 249 insertions(+), 38 deletions(-) create mode 100644 litellm/llms/prompt_templates/image_handling.py create mode 100644 tests/load_tests/test_vertex_load_tests.py create mode 100644 tests/load_tests/vertex_key.json diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 3fc654d25..e555d873b 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -37,6 +37,8 @@ from litellm.types.llms.openai import ( ) from litellm.types.utils import GenericImageParsingChunk +from .image_handling import async_convert_url_to_base64, convert_url_to_base64 + def default_pt(messages): return " ".join(message["content"] for message in messages) @@ -703,44 +705,6 @@ def construct_tool_use_system_prompt( return tool_use_system_prompt -def convert_url_to_base64(url): - import base64 - - client = HTTPHandler(concurrent_limit=1) - for _ in range(3): - try: - - response = client.get(url) - break - except: - pass - if response.status_code == 200: - image_bytes = response.content - base64_image = base64.b64encode(image_bytes).decode("utf-8") - - image_type = response.headers.get("Content-Type", None) - if image_type is not None: - img_type = image_type - else: - img_type = url.split(".")[-1].lower() - if img_type == "jpg" or img_type == "jpeg": - img_type = "image/jpeg" - elif img_type == "png": - img_type = "image/png" - elif img_type == "gif": - img_type = "image/gif" - elif img_type == "webp": - img_type = "image/webp" - else: - raise Exception( - f"Error: Unsupported image format. Format={img_type}. Supported types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']" - ) - - return f"data:{img_type};base64,{base64_image}" - else: - raise Exception(f"Error: Unable to fetch image from URL. url={url}") - - def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk: """ Input: diff --git a/litellm/llms/prompt_templates/image_handling.py b/litellm/llms/prompt_templates/image_handling.py new file mode 100644 index 000000000..90db3dedc --- /dev/null +++ b/litellm/llms/prompt_templates/image_handling.py @@ -0,0 +1,84 @@ +""" +Helper functions to handle images passed in messages +""" + +import base64 + +from httpx import Response + +import litellm +from litellm.caching import InMemoryCache +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, +) + +MAX_IMGS_IN_MEMORY = 10 + +in_memory_cache = InMemoryCache(max_size_in_memory=MAX_IMGS_IN_MEMORY) + + +def _process_image_response(response: Response, url: str) -> str: + if response.status_code != 200: + raise Exception( + f"Error: Unable to fetch image from URL. Status code: {response.status_code}, url={url}" + ) + + image_bytes = response.content + base64_image = base64.b64encode(image_bytes).decode("utf-8") + + image_type = response.headers.get("Content-Type") + if image_type is None: + img_type = url.split(".")[-1].lower() + _img_type = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "webp": "image/webp", + }.get(img_type) + if _img_type is None: + raise Exception( + f"Error: Unsupported image format. Format={_img_type}. Supported types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']" + ) + img_type = _img_type + else: + img_type = image_type + + result = f"data:{img_type};base64,{base64_image}" + in_memory_cache.set_cache(url, result) + return result + + +async def async_convert_url_to_base64(url: str) -> str: + cached_result = in_memory_cache.get_cache(url) + if cached_result: + return cached_result + + client = litellm.module_level_aclient + for _ in range(3): + try: + response = await client.get(url) + return _process_image_response(response, url) + except: + pass + raise Exception( + f"Error: Unable to fetch image from URL after 3 attempts. url={url}" + ) + + +def convert_url_to_base64(url: str) -> str: + cached_result = in_memory_cache.get_cache(url) + if cached_result: + return cached_result + + client = litellm.module_level_client + for _ in range(3): + try: + response = client.get(url) + return _process_image_response(response, url) + except: + pass + raise Exception( + f"Error: Unable to fetch image from URL after 3 attempts. url={url}" + ) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index eda2b1595..457f645aa 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -25,6 +25,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt # litellm.num_retries =3 + litellm.cache = None litellm.success_callback = [] user_message = "Write a short poem about the sky" diff --git a/tests/load_tests/test_vertex_load_tests.py b/tests/load_tests/test_vertex_load_tests.py new file mode 100644 index 000000000..dcb69a62c --- /dev/null +++ b/tests/load_tests/test_vertex_load_tests.py @@ -0,0 +1,149 @@ +import sys +import os + +sys.path.insert(0, os.path.abspath("../..")) + +import asyncio +import litellm +import pytest +import time +import json +import tempfile +from dotenv import load_dotenv + + +def load_vertex_ai_credentials(): + # Define the path to the vertex_key.json file + print("loading vertex ai credentials") + filepath = os.path.dirname(os.path.abspath(__file__)) + vertex_key_path = filepath + "/vertex_key.json" + + # Read the existing content of the file or create an empty dictionary + try: + with open(vertex_key_path, "r") as file: + # Read the file content + print("Read vertexai file path") + content = file.read() + + # If the file is empty or not valid JSON, create an empty dictionary + if not content or not content.strip(): + service_account_key_data = {} + else: + # Attempt to load the existing JSON content + file.seek(0) + service_account_key_data = json.load(file) + except FileNotFoundError: + # If the file doesn't exist, create an empty dictionary + service_account_key_data = {} + + # Update the service_account_key_data with environment variables + private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "") + private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "") + private_key = private_key.replace("\\n", "\n") + service_account_key_data["private_key_id"] = private_key_id + service_account_key_data["private_key"] = private_key + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + # Write the updated content to the temporary files + json.dump(service_account_key_data, temp_file, indent=2) + + # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) + + +@pytest.mark.asyncio +async def test_vertex_load(): + try: + load_vertex_ai_credentials() + percentage_diffs = [] + + for run in range(3): + print(f"\nRun {run + 1}:") + + # Test with text-only message + start_time_text = await make_async_calls(message_type="text") + print("Done with text-only message test") + + # Test with text + image message + start_time_image = await make_async_calls(message_type="image") + print("Done with text + image message test") + + # Compare times and calculate percentage difference + print(f"Time with text-only message: {start_time_text}") + print(f"Time with text + image message: {start_time_image}") + + percentage_diff = ( + (start_time_image - start_time_text) / start_time_text * 100 + ) + percentage_diffs.append(percentage_diff) + print(f"Performance difference: {percentage_diff:.2f}%") + + print("percentage_diffs", percentage_diffs) + # Calculate average percentage difference + avg_percentage_diff = sum(percentage_diffs) / len(percentage_diffs) + print(f"\nAverage performance difference: {avg_percentage_diff:.2f}%") + + # Assert that the average difference is not more than 20% + assert ( + avg_percentage_diff < 20 + ), f"Average performance difference of {avg_percentage_diff:.2f}% exceeds 20% threshold" + + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"An exception occurred - {e}") + + +async def make_async_calls(message_type="text"): + total_tasks = 3 + batch_size = 1 + total_time = 0 + + for batch in range(3): + tasks = [create_async_task(message_type) for _ in range(batch_size)] + + start_time = asyncio.get_event_loop().time() + responses = await asyncio.gather(*tasks) + + for idx, response in enumerate(responses): + print(f"Response from Task {batch * batch_size + idx + 1}: {response}") + + await asyncio.sleep(1) + + batch_time = asyncio.get_event_loop().time() - start_time + total_time += batch_time + + return total_time + + +def create_async_task(message_type): + base_url = "https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001" + + if message_type == "text": + messages = [{"role": "user", "content": "hi"}] + else: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png" + }, + }, + ], + } + ] + + completion_args = { + "model": "vertex_ai/gemini", + "messages": messages, + "max_tokens": 5, + "temperature": 0.7, + "timeout": 10, + "api_base": base_url, + } + return asyncio.create_task(litellm.acompletion(**completion_args)) diff --git a/tests/load_tests/vertex_key.json b/tests/load_tests/vertex_key.json new file mode 100644 index 000000000..e2fd8512b --- /dev/null +++ b/tests/load_tests/vertex_key.json @@ -0,0 +1,13 @@ +{ + "type": "service_account", + "project_id": "adroit-crow-413218", + "private_key_id": "", + "private_key": "", + "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com", + "client_id": "104886546564708740969", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com", + "universe_domain": "googleapis.com" +}