Merge pull request #2840 from BerriAI/litellm_return_cache_key_responses

[FEAT] Proxy - Delete Cache Keys + return cache key in responses
Ishaan Jaff authored on 2024-04-04 11:52:52 -07:00, committed by GitHub
commit 1119cc49a8
3 changed files with 67 additions and 5 deletions
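
Taken together, the diffs below make every proxy response carry its cache key in an `x-litellm-cache-key` response header, which can then be handed to the new `/cache/delete` endpoint. A minimal client-side sketch of that round trip, assuming a proxy at `http://0.0.0.0:4000` with the virtual key `sk-1234` (both taken from the endpoint docstring below; the model name is a placeholder):

```python
import requests

BASE_URL = "http://0.0.0.0:4000"  # proxy address from the docstring example below
HEADERS = {"Authorization": "Bearer sk-1234"}

# 1. Make a normal chat completion request through the proxy.
resp = requests.post(
    f"{BASE_URL}/chat/completions",
    headers=HEADERS,
    json={
        "model": "gpt-3.5-turbo",  # placeholder model name
        "messages": [{"role": "user", "content": "hello"}],
    },
)

# 2. Every response now carries the cache key in a response header.
cache_key = resp.headers.get("x-litellm-cache-key", "")
print("cache key:", cache_key)

# 3. Evict that entry via the new /cache/delete endpoint.
if cache_key:
    requests.post(
        f"{BASE_URL}/cache/delete",
        headers=HEADERS,
        json={"keys": [cache_key]},
    )
```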


@@ -347,6 +347,12 @@ class RedisCache(BaseCache):
             traceback.print_exc()
             raise e
 
+    async def delete_cache_keys(self, keys):
+        _redis_client = self.init_async_client()
+        # keys is a list, unpack it so it gets passed as individual elements to delete
+        async with _redis_client as redis_client:
+            await redis_client.delete(*keys)
+
     def client_list(self):
         client_list = self.redis_client.client_list()
         return client_list
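
The unpacking in `redis_client.delete(*keys)` matters because redis-py's async `delete` takes each key as a separate positional argument (a single Redis `DEL key1 key2 ...` command), not a list. A standalone sketch of the same pattern, assuming a local Redis and redis-py 5.x:

```python
import asyncio

import redis.asyncio as redis


async def main():
    r = redis.Redis(host="localhost", port=6379)  # assumed local instance
    await r.set("key1", "a")
    await r.set("key2", "b")

    keys = ["key1", "key2"]
    # DEL takes individual keys, so the list must be unpacked;
    # the return value is the number of keys actually removed.
    deleted = await r.delete(*keys)
    print(deleted)  # 2
    await r.aclose()  # aclose() assumes redis-py >= 5.0


asyncio.run(main())
```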
@@ -1408,6 +1414,11 @@ class Cache:
             return await self.cache.ping()
         return None
 
+    async def delete_cache_keys(self, keys):
+        if hasattr(self.cache, "delete_cache_keys"):
+            return await self.cache.delete_cache_keys(keys)
+        return None
+
     async def disconnect(self):
         if hasattr(self.cache, "disconnect"):
             await self.cache.disconnect()
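
The `hasattr` guard means `Cache.delete_cache_keys` degrades gracefully: backends that implement the method (here, `RedisCache`) delete the keys, anything else returns `None`. A minimal SDK-level sketch of calling it directly, assuming a reachable Redis instance and this commit's `Cache` constructor (host/port are placeholders):

```python
import asyncio

import litellm
from litellm.caching import Cache

# Placeholder connection details; assumes a reachable Redis instance.
litellm.cache = Cache(type="redis", host="localhost", port=6379)


async def evict(keys):
    # Backends without delete_cache_keys make this a no-op returning None,
    # thanks to the hasattr guard above.
    return await litellm.cache.delete_cache_keys(keys)


asyncio.run(evict(["key1", "key2"]))
```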


@@ -3437,15 +3437,18 @@ async def chat_completion(
 
         # Post Call Processing
         data["litellm_status"] = "success"  # used for alerting
-        if hasattr(response, "_hidden_params"):
-            model_id = response._hidden_params.get("model_id", None) or ""
-        else:
-            model_id = ""
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
 
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
-            custom_headers = {"x-litellm-model-id": model_id}
+            custom_headers = {
+                "x-litellm-model-id": model_id,
+                "x-litellm-cache-key": cache_key,
+            }
             selected_data_generator = select_data_generator(
                 response=response, user_api_key_dict=user_api_key_dict
             )

@@ -3456,6 +3459,7 @@ async def chat_completion(
             )
 
         fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
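
Both branches now expose the key: streaming responses get it via `custom_headers`, non-streaming ones via `fastapi_response.headers`. A client-side sketch of reading the header off a streaming request, assuming the same proxy setup as above (placeholder model name):

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",  # placeholder model name
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    },
    stream=True,
)

# Headers arrive before the SSE body, so the cache key is readable
# even while chunks are still streaming.
print(resp.headers.get("x-litellm-cache-key"))
print(resp.headers.get("x-litellm-model-id"))

for line in resp.iter_lines():
    if line:
        print(line.decode())
```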
@@ -8206,6 +8210,51 @@ async def cache_ping():
         )
 
 
+@router.post(
+    "/cache/delete",
+    tags=["caching"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def cache_delete(request: Request):
+    """
+    Endpoint for deleting a key from the cache. All responses from litellm proxy have `x-litellm-cache-key` in the headers
+
+    Parameters:
+    - **keys**: *Optional[List[str]]* - A list of keys to delete from the cache. Example {"keys": ["key1", "key2"]}
+
+    ```shell
+    curl -X POST "http://0.0.0.0:4000/cache/delete" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{"keys": ["key1", "key2"]}'
+    ```
+    """
+    try:
+        if litellm.cache is None:
+            raise HTTPException(
+                status_code=503, detail="Cache not initialized. litellm.cache is None"
+            )
+
+        request_data = await request.json()
+        keys = request_data.get("keys", None)
+
+        if litellm.cache.type == "redis":
+            await litellm.cache.delete_cache_keys(keys=keys)
+            return {
+                "status": "success",
+            }
+        else:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported",
+            )
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Cache Delete Failed({str(e)})",
+        )
+
+
 @router.get(
     "/cache/redis/info",
     tags=["caching"],


@@ -3132,6 +3132,8 @@ def client(original_function):
                         target=logging_obj.success_handler,
                         args=(cached_result, start_time, end_time, cache_hit),
                     ).start()
+                    cache_key = kwargs.get("preset_cache_key", None)
+                    cached_result._hidden_params["cache_key"] = cache_key
                     return cached_result
                 elif (
                     call_type == CallTypes.aembedding.value
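
This is where the header value originates on cache hits: the client wrapper copies the `preset_cache_key` computed for the request into `_hidden_params`, which `chat_completion` then reads. The same field is visible at the SDK level; a sketch, assuming a Redis cache and that `_hidden_params` is a plain dict in this version (connection details are placeholders):

```python
import litellm
from litellm.caching import Cache

litellm.cache = Cache(type="redis", host="localhost", port=6379)  # placeholder connection

messages = [{"role": "user", "content": "what is 2 + 2?"}]

# First call populates the cache; the second should be served from it.
litellm.completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response = litellm.completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# On a cache hit, the wrapper above stashes the preset cache key here.
print(response._hidden_params.get("cache_key"))
```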