Merge pull request #2542 from BerriAI/litellm_redis_perf_improvements

89% Caching improvement - Async Redis completion calls + batch redis GET requests for a given key + call type
Krish Dholakia 2024-03-15 18:58:36 -07:00 committed by GitHub
commit 4969ae0e9d
7 changed files with 287 additions and 116 deletions
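
For context on the title above, here is a rough standalone sketch of the batched-read pattern this change introduces, written against a plain redis.asyncio client. The host/port, key names, and the assumption that cached values are JSON-encoded are illustrative only; this is not litellm's own API.

# Hedged sketch: discover keys under a namespace with SCAN, then fetch them in one pipelined round trip.
import asyncio
import json

import redis.asyncio as redis


async def batch_get_namespace(client: redis.Redis, api_key: str, call_type: str) -> dict:
    prefix = f"litellm:{api_key}:{call_type}"
    # SCAN (not KEYS) so Redis is never blocked while enumerating matching keys.
    keys = [key async for key in client.scan_iter(match=prefix + "*", count=100)]
    if not keys:
        return {}
    # Queue one GET per key and execute them all in a single round trip.
    async with client.pipeline(transaction=True) as pipe:
        for key in keys:
            pipe.get(key)
        values = await pipe.execute()
    # Assumes values were stored as JSON strings (as the proxy cache's set() does).
    return {
        k.decode("utf-8"): (json.loads(v) if v is not None else None)
        for k, v in zip(keys, values)
    }


async def main():
    client = redis.Redis(host="localhost", port=6379)
    print(await batch_get_namespace(client, api_key="sk-1234", call_type="acompletion"))


if __name__ == "__main__":
    asyncio.run(main())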


@@ -129,6 +129,16 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
keys = []
_redis_client = self.init_async_client()
async with _redis_client as redis_client:
async for key in redis_client.scan_iter(match=pattern + "*", count=count):
keys.append(key)
if len(keys) >= count:
break
return keys
async def async_set_cache(self, key, value, **kwargs):
_redis_client = self.init_async_client()
async with _redis_client as redis_client:
@@ -140,6 +150,9 @@ class RedisCache(BaseCache):
await redis_client.set(
name=key, value=json.dumps(value), ex=ttl, get=True
)
print_verbose(
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
print_verbose(
@@ -172,8 +185,6 @@ class RedisCache(BaseCache):
return results
except Exception as e:
print_verbose(f"Error occurred in pipeline write - {str(e)}")
# NON blocking - notify users Redis is throwing an exception
logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
def _get_cache_logic(self, cached_response: Any):
"""
@@ -208,7 +219,7 @@ class RedisCache(BaseCache):
_redis_client = self.init_async_client()
async with _redis_client as redis_client:
try:
print_verbose(f"Get Redis Cache: key: {key}")
print_verbose(f"Get Async Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
print_verbose(
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
@@ -217,8 +228,39 @@ class RedisCache(BaseCache):
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
print_verbose(
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
)
async def async_get_cache_pipeline(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
key_value_dict = {}
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
# Queue the get operations in the pipeline for all keys.
for cache_key in key_list:
pipe.get(cache_key) # Queue GET command in pipeline
# Execute the pipeline and await the results.
results = await pipe.execute()
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
decoded_results = {
k.decode("utf-8"): self._get_cache_logic(v)
for k, v in key_value_dict.items()
}
return decoded_results
except Exception as e:
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
def flush_cache(self):
self.redis_client.flushall()
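
A hedged usage sketch of the two helpers added above, async_scan_iter and async_get_cache_pipeline. The two method calls mirror the diff; the RedisCache constructor arguments and key names are assumptions for illustration.

# Hypothetical caller: read one api-key/call-type namespace from Redis in bulk.
import asyncio

from litellm.caching import RedisCache


async def read_namespace(api_key: str, call_type: str) -> dict:
    cache = RedisCache(host="localhost", port=6379, password=None)  # assumed constructor args
    # SCAN for every cached entry under this namespace (the helper appends the trailing "*")...
    keys = await cache.async_scan_iter(pattern=f"litellm:{api_key}:{call_type}", count=100)
    if not keys:
        return {}
    # ...then fetch all of their values in a single pipelined round trip.
    return await cache.async_get_cache_pipeline(key_list=keys)


if __name__ == "__main__":
    print(asyncio.run(read_namespace(api_key="sk-1234", call_type="acompletion")))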
@@ -1001,6 +1043,10 @@ class Cache:
if self.namespace is not None:
hash_hex = f"{self.namespace}:{hash_hex}"
print_verbose(f"Hashed Key with Namespace: {hash_hex}")
elif kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
_namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
hash_hex = f"{_namespace}:{hash_hex}"
print_verbose(f"Hashed Key with Namespace: {hash_hex}")
return hash_hex
def generate_streaming_content(self, content):
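
The redis_namespace branch above, reduced to a toy illustration. The hash and namespace strings are made up; in practice the hash is derived from the request parameters and the namespace is set by the batch_redis_requests pre-call hook.

# Illustrative only: the namespace carried in request metadata prefixes the hashed cache key.
hash_hex = "a1b2c3d4"  # stand-in for the real hash of the call's parameters
metadata = {"redis_namespace": "litellm:sk-1234:acompletion"}

_namespace = metadata.get("redis_namespace", None)
if _namespace is not None:
    hash_hex = f"{_namespace}:{hash_hex}"

print(hash_hex)  # litellm:sk-1234:acompletion:a1b2c3d4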


@@ -9,6 +9,12 @@ model_list:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
cache: true
cache_params:
type: redis
callbacks: ["batch_redis_requests"]
general_settings:
master_key: sk-1234
database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require"
# database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require"


@@ -0,0 +1,124 @@
# What this does:
## Gets a key's Redis cache entries and stores them in memory for 1 minute.
## This reduces the number of Redis GET requests the proxy makes during high traffic.
### [BETA] This is in beta and might change.
from typing import Optional, Literal
import litellm
from litellm.caching import DualCache, RedisCache, InMemoryCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException
import json, traceback
class _PROXY_BatchRedisRequests(CustomLogger):
# Class variables or attributes
in_memory_cache: Optional[InMemoryCache] = None
def __init__(self):
litellm.cache.async_get_cache = (
self.async_get_cache
)  # map litellm's 'async_get_cache' to our custom function
def print_verbose(
self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
):
if debug_level == "DEBUG":
verbose_proxy_logger.debug(print_statement)
elif debug_level == "INFO":
verbose_proxy_logger.info(print_statement)
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
"""
Get the user key.
Check if a key starting with `litellm:<api_key>:<call_type>:` exists in memory.
If not, fetch the relevant cache entries from Redis.
"""
api_key = user_api_key_dict.api_key
cache_key_name = f"litellm:{api_key}:{call_type}"
self.in_memory_cache = cache.in_memory_cache
key_value_dict = {}
in_memory_cache_exists = False
for key in cache.in_memory_cache.cache_dict.keys():
if isinstance(key, str) and key.startswith(cache_key_name):
in_memory_cache_exists = True
if in_memory_cache_exists == False and litellm.cache is not None:
"""
- Check if `litellm.Cache` is redis
- Get the relevant values
"""
if litellm.cache.type is not None and isinstance(
litellm.cache.cache, RedisCache
):
# Initialize an empty list to store the keys
keys = []
self.print_verbose(f"cache_key_name: {cache_key_name}")
# Use the SCAN iterator to fetch keys matching the pattern
keys = await litellm.cache.cache.async_scan_iter(
pattern=cache_key_name, count=100
)
# If the truly "last" key (by time or another criterion) is needed,
# the key naming or storage strategy must allow that determination;
# sort or filter the keys here as needed.
self.print_verbose(f"redis keys: {keys}")
if len(keys) > 0:
key_value_dict = (
await litellm.cache.cache.async_get_cache_pipeline(
key_list=keys
)
)
## Add to cache
if len(key_value_dict.items()) > 0:
await cache.in_memory_cache.async_set_cache_pipeline(
cache_list=list(key_value_dict.items()), ttl=60
)
## Set cache namespace if it's a miss
data["metadata"]["redis_namespace"] = cache_key_name
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
async def async_get_cache(self, *args, **kwargs):
"""
- Check if the cache key is in-memory
- Else return None
"""
try: # never block execution
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = litellm.cache.get_cache_key(
*args, **kwargs
) # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
if cache_key is not None and self.in_memory_cache is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = self.in_memory_cache.get_cache(
cache_key, *args, **kwargs
)
return litellm.cache._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception as e:
return None
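
The "hold Redis values in memory for one minute" idea described at the top of this file, reduced to a minimal standalone sketch. This is an illustrative TTL dict, not litellm's InMemoryCache or DualCache.

# Hypothetical in-memory layer: values expire after a TTL, after which callers fall back to a batched Redis read.
import time
from typing import Any, Dict, List, Optional, Tuple


class TTLMemoryCache:
    def __init__(self) -> None:
        self.cache_dict: Dict[str, Any] = {}
        self.expiry: Dict[str, float] = {}

    def set_many(self, items: List[Tuple[str, Any]], ttl: float = 60.0) -> None:
        deadline = time.time() + ttl
        for key, value in items:
            self.cache_dict[key] = value
            self.expiry[key] = deadline

    def get(self, key: str) -> Optional[Any]:
        if key in self.cache_dict and time.time() < self.expiry[key]:
            return self.cache_dict[key]
        # Expired or missing: drop any stale entry and signal a miss.
        self.cache_dict.pop(key, None)
        self.expiry.pop(key, None)
        return None


cache = TTLMemoryCache()
cache.set_many([("litellm:sk-1234:acompletion:a1b2c3d4", {"cached": "response"})], ttl=60)
print(cache.get("litellm:sk-1234:acompletion:a1b2c3d4"))  # hit within 60s, None afterwards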


@@ -1798,6 +1798,16 @@ class ProxyConfig:
_ENTERPRISE_PromptInjectionDetection()
)
imported_list.append(prompt_injection_detection_obj)
elif (
isinstance(callback, str)
and callback == "batch_redis_requests"
):
from litellm.proxy.hooks.batch_redis_get import (
_PROXY_BatchRedisRequests,
)
batch_redis_obj = _PROXY_BatchRedisRequests()
imported_list.append(batch_redis_obj)
else:
imported_list.append(
get_instance_fn(


@@ -474,11 +474,10 @@ def test_redis_cache_completion_stream():
# test_redis_cache_completion_stream()
def test_redis_cache_acompletion_stream():
import asyncio
@pytest.mark.asyncio
async def test_redis_cache_acompletion_stream():
try:
litellm.set_verbose = False
litellm.set_verbose = True
random_word = generate_random_word()
messages = [
{
@@ -496,37 +495,31 @@ def test_redis_cache_acompletion_stream():
response_1_content = ""
response_2_content = ""
async def call1():
nonlocal response_1_content
response1 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
response1 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content
response2 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
response2 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert (
@@ -536,14 +529,15 @@ def test_redis_cache_acompletion_stream():
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
print(f"{str(e)}\n\n{traceback.format_exc()}")
raise e
# test_redis_cache_acompletion_stream()
def test_redis_cache_acompletion_stream_bedrock():
@pytest.mark.asyncio
async def test_redis_cache_acompletion_stream_bedrock():
import asyncio
try:
@@ -565,39 +559,33 @@ def test_redis_cache_acompletion_stream_bedrock():
response_1_content = ""
response_2_content = ""
async def call1():
nonlocal response_1_content
response1 = await litellm.acompletion(
model="bedrock/anthropic.claude-v2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
response1 = await litellm.acompletion(
model="bedrock/anthropic.claude-v2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content
response2 = await litellm.acompletion(
model="bedrock/anthropic.claude-v2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
response2 = await litellm.acompletion(
model="bedrock/anthropic.claude-v2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert (
@@ -612,8 +600,8 @@ def test_redis_cache_acompletion_stream_bedrock():
raise e
@pytest.mark.skip(reason="AWS Suspended Account")
def test_s3_cache_acompletion_stream_azure():
@pytest.mark.asyncio
async def test_s3_cache_acompletion_stream_azure():
import asyncio
try:
@@ -637,41 +625,35 @@ def test_s3_cache_acompletion_stream_azure():
response_1_created = ""
response_2_created = ""
async def call1():
nonlocal response_1_content, response_1_created
response1 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_created = chunk.created
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
response1 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_created = chunk.created
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content, response_2_created
response2 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
response_2_created = chunk.created
print(response_2_content)
response2 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
response_2_created = chunk.created
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)


@@ -97,27 +97,23 @@ class TmpFunction:
)
def test_async_chat_openai_stream():
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
try:
tmp_function = TmpFunction()
litellm.set_verbose = True
litellm.success_callback = [tmp_function.async_test_logging_fn]
complete_streaming_response = ""
async def call_gpt():
nonlocal complete_streaming_response
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
async for chunk in response:
complete_streaming_response += (
chunk["choices"][0]["delta"]["content"] or ""
)
print(complete_streaming_response)
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
async for chunk in response:
complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
print(complete_streaming_response)
asyncio.run(call_gpt())
complete_streaming_response = complete_streaming_response.strip("'")
response1 = tmp_function.complete_streaming_response_in_callback["choices"][0][
"message"
@@ -130,7 +126,7 @@ def test_async_chat_openai_stream():
assert tmp_function.async_success == True
except Exception as e:
print(e)
pytest.fail(f"An error occurred - {str(e)}")
pytest.fail(f"An error occurred - {str(e)}\n\n{traceback.format_exc()}")
# test_async_chat_openai_stream()


@@ -72,7 +72,7 @@ from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache, RedisSemanticCache
from .caching import S3Cache, RedisSemanticCache, RedisCache
from .exceptions import (
AuthenticationError,
BadRequestError,
@@ -1795,7 +1795,12 @@ class Logging:
)
result = kwargs["async_complete_streaming_response"]
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
if litellm.cache is not None and not isinstance(
litellm.cache.cache, S3Cache
):
await litellm.cache.async_add_cache(result, **kwargs)
else:
litellm.cache.add_cache(result, **kwargs)
if isinstance(callback, CustomLogger): # custom logger class
print_verbose(
f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
@@ -2806,7 +2811,9 @@ def client(original_function):
):
if len(cached_result) == 1 and cached_result[0] is None:
cached_result = None
elif isinstance(litellm.cache.cache, RedisSemanticCache):
elif isinstance(
litellm.cache.cache, RedisSemanticCache
) or isinstance(litellm.cache.cache, RedisCache):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key