diff --git a/litellm/caching.py b/litellm/caching.py
index 9df95f199..ed856f86f 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -129,6 +129,16 @@ class RedisCache(BaseCache):
                 f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
             )

+    async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
+        keys = []
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
+            async for key in redis_client.scan_iter(match=pattern + "*", count=count):
+                keys.append(key)
+                if len(keys) >= count:
+                    break
+        return keys
+
     async def async_set_cache(self, key, value, **kwargs):
         _redis_client = self.init_async_client()
         async with _redis_client as redis_client:
@@ -140,6 +150,9 @@ class RedisCache(BaseCache):
                 await redis_client.set(
                     name=key, value=json.dumps(value), ex=ttl, get=True
                 )
+                print_verbose(
+                    f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+                )
             except Exception as e:
                 # NON blocking - notify users Redis is throwing an exception
                 print_verbose(
@@ -172,8 +185,6 @@ class RedisCache(BaseCache):
             return results
         except Exception as e:
             print_verbose(f"Error occurred in pipeline write - {str(e)}")
-            # NON blocking - notify users Redis is throwing an exception
-            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

     def _get_cache_logic(self, cached_response: Any):
         """
@@ -208,7 +219,7 @@ class RedisCache(BaseCache):
         _redis_client = self.init_async_client()
         async with _redis_client as redis_client:
             try:
-                print_verbose(f"Get Redis Cache: key: {key}")
+                print_verbose(f"Get Async Redis Cache: key: {key}")
                 cached_response = await redis_client.get(key)
                 print_verbose(
                     f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
@@ -217,8 +228,39 @@ class RedisCache(BaseCache):
                 return response
             except Exception as e:
                 # NON blocking - notify users Redis is throwing an exception
-                traceback.print_exc()
-                logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
+                print_verbose(
+                    f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
+                )
+
+    async def async_get_cache_pipeline(self, key_list) -> dict:
+        """
+        Use Redis for bulk read operations
+        """
+        _redis_client = await self.init_async_client()
+        key_value_dict = {}
+        try:
+            async with _redis_client as redis_client:
+                async with redis_client.pipeline(transaction=True) as pipe:
+                    # Queue the get operations in the pipeline for all keys.
+                    for cache_key in key_list:
+                        pipe.get(cache_key)  # Queue GET command in pipeline
+
+                    # Execute the pipeline and await the results.
+                    results = await pipe.execute()
+
+                # Associate the results back with their keys.
+                # 'results' is a list of values corresponding to the order of keys in 'key_list'.
+                key_value_dict = dict(zip(key_list, results))
+
+            decoded_results = {
+                k.decode("utf-8"): self._get_cache_logic(v)
+                for k, v in key_value_dict.items()
+            }
+
+            return decoded_results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline read - {str(e)}")
+            return key_value_dict

     def flush_cache(self):
         self.redis_client.flushall()
@@ -1001,6 +1043,10 @@ class Cache:
         if self.namespace is not None:
             hash_hex = f"{self.namespace}:{hash_hex}"
             print_verbose(f"Hashed Key with Namespace: {hash_hex}")
+        elif kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
+            _namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
+            hash_hex = f"{_namespace}:{hash_hex}"
+            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
         return hash_hex

     def generate_streaming_content(self, content):
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index aab9b3d5c..1c41d79fc 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -9,6 +9,12 @@ model_list:
       model: gpt-3.5-turbo-1106
       api_key: os.environ/OPENAI_API_KEY

+litellm_settings:
+  cache: true
+  cache_params:
+    type: redis
+  callbacks: ["batch_redis_requests"]
+
 general_settings:
   master_key: sk-1234
-  database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require"
\ No newline at end of file
+  # database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require"
\ No newline at end of file
diff --git a/litellm/proxy/hooks/batch_redis_get.py b/litellm/proxy/hooks/batch_redis_get.py
new file mode 100644
index 000000000..71588c9d4
--- /dev/null
+++ b/litellm/proxy/hooks/batch_redis_get.py
@@ -0,0 +1,124 @@
+# What this does?
+## Gets a key's redis cache, and store it in memory for 1 minute.
+## This reduces the number of REDIS GET requests made during high-traffic by the proxy.
+### [BETA] this is in Beta. And might change.
+
+from typing import Optional, Literal
+import litellm
+from litellm.caching import DualCache, RedisCache, InMemoryCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from litellm._logging import verbose_proxy_logger
+from fastapi import HTTPException
+import json, traceback
+
+
+class _PROXY_BatchRedisRequests(CustomLogger):
+    # Class variables or attributes
+    in_memory_cache: Optional[InMemoryCache] = None
+
+    def __init__(self):
+        litellm.cache.async_get_cache = (
+            self.async_get_cache
+        )  # map the litellm 'get_cache' function to our custom function
+
+    def print_verbose(
+        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
+    ):
+        if debug_level == "DEBUG":
+            verbose_proxy_logger.debug(print_statement)
+        elif debug_level == "INFO":
+            verbose_proxy_logger.debug(print_statement)
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        try:
+            """
+            Get the user key
+
+            Check if a key starting with `litellm:<api_key>:<call_type>` exists in-memory
+
+            If no, then get the relevant cache from redis
+            """
+            api_key = user_api_key_dict.api_key
+            cache_key_name = f"litellm:{api_key}:{call_type}"
+            self.in_memory_cache = cache.in_memory_cache
+
+            in_memory_cache_exists = False
+            for key in cache.in_memory_cache.cache_dict.keys():
+                if isinstance(key, str) and key.startswith(cache_key_name):
+                    in_memory_cache_exists = True
+
+            if in_memory_cache_exists == False and litellm.cache is not None:
+                # Check if `litellm.cache` is redis, then SCAN for keys under this api key + call type
+                if isinstance(litellm.cache.cache, RedisCache):
+                    keys = await litellm.cache.cache.async_scan_iter(
+                        pattern=cache_key_name, count=100
+                    )
+                    if len(keys) > 0:
+                        key_value_dict = (
+                            await litellm.cache.cache.async_get_cache_pipeline(
+                                key_list=keys
+                            )
+                        )
+
+                        ## Add to cache
+                        if len(key_value_dict.items()) > 0:
+                            await cache.in_memory_cache.async_set_cache_pipeline(
+                                cache_list=list(key_value_dict.items()), ttl=60
+                            )
+                        ## Set cache namespace if it's a miss
+                        data["metadata"]["redis_namespace"] = cache_key_name
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            traceback.print_exc()
+
+    async def async_get_cache(self, *args, **kwargs):
+        """
+        - Check if the cache key is in-memory
+
+        - Else return None
+        """
+        try:  # never block execution
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = litellm.cache.get_cache_key(
+                    *args, **kwargs
+                )  # returns "<redis_namespace>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
+            if cache_key is not None and self.in_memory_cache is not None:
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
+                cached_result = self.in_memory_cache.get_cache(
+                    cache_key, *args, **kwargs
+                )
+                return litellm.cache._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
+        except Exception as e:
+            return None
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 52dac7524..7ad08eac8 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1798,6 +1798,16 @@ class ProxyConfig:
                             _ENTERPRISE_PromptInjectionDetection()
                         )
                         imported_list.append(prompt_injection_detection_obj)
+                    elif (
+                        isinstance(callback, str)
+                        and callback == "batch_redis_requests"
+                    ):
+                        from litellm.proxy.hooks.batch_redis_get import (
+                            _PROXY_BatchRedisRequests,
+                        )
+
+                        batch_redis_obj = _PROXY_BatchRedisRequests()
+                        imported_list.append(batch_redis_obj)
                     else:
                         imported_list.append(
                             get_instance_fn(
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 07d39b086..aa0681c61 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -474,11 +474,10 @@ def test_redis_cache_completion_stream():
 # test_redis_cache_completion_stream()


-def test_redis_cache_acompletion_stream():
-    import asyncio
-
+@pytest.mark.asyncio
+async def test_redis_cache_acompletion_stream():
     try:
-        litellm.set_verbose = False
+        litellm.set_verbose = True
         random_word = generate_random_word()
         messages = [
             {
@@ -496,37 +495,31 @@ def test_redis_cache_acompletion_stream():
         response_1_content = ""
         response_2_content = ""

-        async def call1():
-            nonlocal response_1_content
-            response1 = await litellm.acompletion(
-                model="gpt-3.5-turbo",
-                messages=messages,
-                max_tokens=40,
-                temperature=1,
-                stream=True,
-            )
-            async for chunk in response1:
-                response_1_content += chunk.choices[0].delta.content or ""
-            print(response_1_content)
+        response1 = await litellm.acompletion(
+            model="gpt-3.5-turbo",
+            messages=messages,
+            max_tokens=40,
+            temperature=1,
+            stream=True,
+        )
+        async for chunk in response1:
+            response_1_content += chunk.choices[0].delta.content or ""
+        print(response_1_content)

-        asyncio.run(call1())
         time.sleep(0.5)
         print("\n\n Response 1 content: ", response_1_content, "\n\n")

-        async def call2():
-            nonlocal response_2_content
-            response2 = await litellm.acompletion(
-                model="gpt-3.5-turbo",
-                messages=messages,
-                max_tokens=40,
-                temperature=1,
-                stream=True,
-            )
-            async for chunk in response2:
-                response_2_content += chunk.choices[0].delta.content or ""
-            print(response_2_content)
+        response2 = await litellm.acompletion(
+            model="gpt-3.5-turbo",
+            messages=messages,
+            max_tokens=40,
+            temperature=1,
+            stream=True,
+        )
+        async for chunk in response2:
+            response_2_content += chunk.choices[0].delta.content or ""
+        print(response_2_content)

-        asyncio.run(call2())
         print("\nresponse 1", response_1_content)
         print("\nresponse 2", response_2_content)
         assert (
@@ -536,14 +529,15 @@ def test_redis_cache_acompletion_stream():
         litellm.success_callback = []
         litellm._async_success_callback = []
     except Exception as e:
-        print(e)
+        print(f"{str(e)}\n\n{traceback.format_exc()}")
         raise e


 # test_redis_cache_acompletion_stream()


-def test_redis_cache_acompletion_stream_bedrock():
+@pytest.mark.asyncio
+async def test_redis_cache_acompletion_stream_bedrock():
     import asyncio

     try:
@@ -565,39 +559,33 @@ def test_redis_cache_acompletion_stream_bedrock():
         response_1_content = ""
         response_2_content = ""

-        async def call1():
-            nonlocal response_1_content
-            response1 = await litellm.acompletion(
-                model="bedrock/anthropic.claude-v2",
-                messages=messages,
-                max_tokens=40,
-                temperature=1,
-                stream=True,
-            )
-            async for chunk in response1:
-                print(chunk)
-                response_1_content += chunk.choices[0].delta.content or ""
-            print(response_1_content)
+        response1 = await litellm.acompletion(
+            model="bedrock/anthropic.claude-v2",
+            messages=messages,
+            max_tokens=40,
+            temperature=1,
+            stream=True,
+        )
+        async for chunk in response1:
+            print(chunk)
+            response_1_content += chunk.choices[0].delta.content or ""
+        print(response_1_content)

-        asyncio.run(call1())
         time.sleep(0.5)
         print("\n\n Response 1 content: ", response_1_content, "\n\n")

-        async def call2():
-            nonlocal response_2_content
-            response2 = await litellm.acompletion(
-                model="bedrock/anthropic.claude-v2",
-                messages=messages,
-                max_tokens=40,
-                temperature=1,
-                stream=True,
-            )
-            async for chunk in response2:
-                print(chunk)
-                response_2_content += chunk.choices[0].delta.content or ""
-            print(response_2_content)
+        response2 = await litellm.acompletion(
+            model="bedrock/anthropic.claude-v2",
+            messages=messages,
+            max_tokens=40,
+            temperature=1,
+            stream=True,
+        )
+        async for chunk in response2:
+            print(chunk)
+            response_2_content += chunk.choices[0].delta.content or ""
+        print(response_2_content)

-        asyncio.run(call2())
         print("\nresponse 1", response_1_content)
         print("\nresponse 2", response_2_content)
         assert (
@@ -612,8 +600,8 @@ def test_redis_cache_acompletion_stream_bedrock():
         raise e


-@pytest.mark.skip(reason="AWS Suspended Account")
-def test_s3_cache_acompletion_stream_azure():
+@pytest.mark.asyncio
+async def test_s3_cache_acompletion_stream_azure():
     import asyncio

     try:
@@ -637,41 +625,35 @@ def test_s3_cache_acompletion_stream_azure():
         response_1_created = ""
         response_2_created = ""

-        async def call1():
-            nonlocal response_1_content, response_1_created
-            response1 = await litellm.acompletion(
-                model="azure/chatgpt-v-2",
-                messages=messages,
-                max_tokens=40,
-                temperature=1,
-                stream=True,
-            )
-            async for chunk in response1:
-                print(chunk)
-                response_1_created = chunk.created
-                response_1_content += chunk.choices[0].delta.content or ""
-            print(response_1_content)
+        response1 = await litellm.acompletion(
+            model="azure/chatgpt-v-2",
+            messages=messages,
+            max_tokens=40,
+            temperature=1,
+            stream=True,
+        )
+        async for chunk in response1:
+            print(chunk)
+            response_1_created = chunk.created
+            response_1_content += chunk.choices[0].delta.content or ""
+        print(response_1_content)

-        asyncio.run(call1())
         time.sleep(0.5)
         print("\n\n Response 1 content: ", response_1_content, "\n\n")

-        async def call2():
-            nonlocal response_2_content, response_2_created
-            response2 = await litellm.acompletion(
-                model="azure/chatgpt-v-2",
-                messages=messages,
-                max_tokens=40,
-                temperature=1,
-                stream=True,
-            )
-            async for chunk in response2:
-                print(chunk)
-                response_2_content += chunk.choices[0].delta.content or ""
-                response_2_created = chunk.created
-            print(response_2_content)
+        response2 = await litellm.acompletion(
+            model="azure/chatgpt-v-2",
+            messages=messages,
+            max_tokens=40,
+            temperature=1,
+            stream=True,
+        )
+        async for chunk in response2:
+            print(chunk)
+            response_2_content += chunk.choices[0].delta.content or ""
+            response_2_created = chunk.created
+        print(response_2_content)

-        asyncio.run(call2())
         print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content) diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index 0a8f7b941..b2e2b7d22 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -97,27 +97,23 @@ class TmpFunction: ) -def test_async_chat_openai_stream(): +@pytest.mark.asyncio +async def test_async_chat_openai_stream(): try: tmp_function = TmpFunction() litellm.set_verbose = True litellm.success_callback = [tmp_function.async_test_logging_fn] complete_streaming_response = "" - async def call_gpt(): - nonlocal complete_streaming_response - response = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], - stream=True, - ) - async for chunk in response: - complete_streaming_response += ( - chunk["choices"][0]["delta"]["content"] or "" - ) - print(complete_streaming_response) + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + stream=True, + ) + async for chunk in response: + complete_streaming_response += chunk["choices"][0]["delta"]["content"] or "" + print(complete_streaming_response) - asyncio.run(call_gpt()) complete_streaming_response = complete_streaming_response.strip("'") response1 = tmp_function.complete_streaming_response_in_callback["choices"][0][ "message" @@ -130,7 +126,7 @@ def test_async_chat_openai_stream(): assert tmp_function.async_success == True except Exception as e: print(e) - pytest.fail(f"An error occurred - {str(e)}") + pytest.fail(f"An error occurred - {str(e)}\n\n{traceback.format_exc()}") # test_async_chat_openai_stream() diff --git a/litellm/utils.py b/litellm/utils.py index 7ad4107a9..95b18421f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -72,7 +72,7 @@ from .integrations.litedebugger import LiteDebugger from .proxy._types import KeyManagementSystem from openai import OpenAIError as OriginalError from openai._models import BaseModel as OpenAIObject -from .caching import S3Cache, RedisSemanticCache +from .caching import S3Cache, RedisSemanticCache, RedisCache from .exceptions import ( AuthenticationError, BadRequestError, @@ -1795,7 +1795,12 @@ class Logging: ) result = kwargs["async_complete_streaming_response"] # only add to cache once we have a complete streaming response - litellm.cache.add_cache(result, **kwargs) + if litellm.cache is not None and not isinstance( + litellm.cache.cache, S3Cache + ): + await litellm.cache.async_add_cache(result, **kwargs) + else: + litellm.cache.add_cache(result, **kwargs) if isinstance(callback, CustomLogger): # custom logger class print_verbose( f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}" @@ -2806,7 +2811,9 @@ def client(original_function): ): if len(cached_result) == 1 and cached_result[0] is None: cached_result = None - elif isinstance(litellm.cache.cache, RedisSemanticCache): + elif isinstance( + litellm.cache.cache, RedisSemanticCache + ) or isinstance(litellm.cache.cache, RedisCache): preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) kwargs["preset_cache_key"] = ( preset_cache_key # for streaming calls, we need to pass the preset_cache_key