feat(proxy_server.py): enable cache controls per key + no-store cache flag

Krrish Dholakia 2024-01-30 20:46:50 -08:00
parent 37de964da4
commit f9acad87dc
8 changed files with 108 additions and 42 deletions

View file

@@ -5,6 +5,15 @@ from datetime import datetime
import uuid, json, sys, os


def hash_token(token: str):
    import hashlib

    # Hash the string using SHA-256
    hashed_token = hashlib.sha256(token.encode()).hexdigest()
    return hashed_token


class LiteLLMBase(BaseModel):
    """
    Implements default functions, all pydantic objects should have.
@@ -137,6 +146,7 @@ class GenerateRequestBase(LiteLLMBase):
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    budget_duration: Optional[str] = None
    allowed_cache_controls: Optional[list] = []
class GenerateKeyRequest(GenerateRequestBase):
@@ -177,25 +187,6 @@ class UpdateKeyRequest(GenerateKeyRequest):
    metadata: Optional[dict] = None


class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
    """
    Return the row in the db
    """

    api_key: Optional[str] = None
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: Optional[float] = 0
    max_budget: Optional[float] = None
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    duration: str = "1h"
    metadata: dict = {}
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None


class DeleteKeyRequest(LiteLLMBase):
    keys: List
@@ -320,22 +311,39 @@ class ConfigYAML(LiteLLMBase):
class LiteLLM_VerificationToken(LiteLLMBase):
    token: str
    token: Optional[str] = None
    key_name: Optional[str] = None
    key_alias: Optional[str] = None
    spend: float = 0.0
    max_budget: Optional[float] = None
    expires: Union[str, None]
    models: List[str]
    aliases: Dict[str, str] = {}
    config: Dict[str, str] = {}
    user_id: Union[str, None]
    max_parallel_requests: Union[int, None]
    metadata: Dict[str, str] = {}
    expires: Optional[str] = None
    models: List = []
    aliases: Dict = {}
    config: Dict = {}
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    metadata: Dict = {}
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    budget_duration: Optional[str] = None
    budget_reset_at: Optional[datetime] = None
    allowed_cache_controls: Optional[list] = []


class UserAPIKeyAuth(
    LiteLLM_VerificationToken
):  # the expected response object for user api key auth
    """
    Return the row in the db
    """

    api_key: Optional[str] = None

    @root_validator(pre=True)
    def check_api_key(cls, values):
        if values.get("api_key") is not None:
            values.update({"token": hash_token(values.get("api_key"))})
        return values


class LiteLLM_Config(LiteLLMBase):
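
As a side note on the refactor above: UserAPIKeyAuth now inherits every per-key field (budgets, limits, allowed_cache_controls) from LiteLLM_VerificationToken, and its check_api_key validator stores only the SHA-256 hash of a raw key. A minimal sketch of that behavior, assuming a pydantic v1-style root_validator and reusing the field names from the diff (the class name here is just for illustration):

import hashlib
from typing import Optional
from pydantic import BaseModel, root_validator

def hash_token(token: str) -> str:
    # Same idea as the helper in the diff: SHA-256 hex digest of the raw key
    return hashlib.sha256(token.encode()).hexdigest()

class UserAPIKeyAuthSketch(BaseModel):
    token: Optional[str] = None
    api_key: Optional[str] = None
    allowed_cache_controls: Optional[list] = []

    @root_validator(pre=True)
    def check_api_key(cls, values):
        # Derive the hashed token from the raw api_key before the model is built
        if values.get("api_key") is not None:
            values.update({"token": hash_token(values.get("api_key"))})
        return values

auth = UserAPIKeyAuthSketch(api_key="sk-1234", allowed_cache_controls=["ttl"])
print(auth.token)  # SHA-256 of "sk-1234"; the raw key stays only in api_key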

View file

@@ -0,0 +1,55 @@
# What this does?
## Checks if key is allowed to use the cache controls passed in to the completion() call

from typing import Optional
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
import json, traceback


class CacheControlCheck(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        pass

    def print_verbose(self, print_statement):
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            self.print_verbose(f"Inside Cache Control Check Pre-Call Hook")
            allowed_cache_controls = user_api_key_dict.allowed_cache_controls

            if (allowed_cache_controls is None) or (
                len(allowed_cache_controls) == 0
            ):  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
                return

            if data.get("cache", None) is None:
                return

            cache_args = data.get("cache", None)
            if isinstance(cache_args, dict):
                for k, v in cache_args.items():
                    if k not in allowed_cache_controls:
                        raise HTTPException(
                            status_code=403,
                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
                        )
            else:  # invalid cache
                return

        except HTTPException as e:
            raise e
        except Exception as e:
            traceback.print_exc()
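
In practice the hook above means a virtual key can only use the cache controls the admin granted it. A hedged sketch of the check in isolation (the DualCache instance and call_type value are placeholders for what the proxy would pass in):

import asyncio

from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.cache_control_check import CacheControlCheck

async def demo():
    check = CacheControlCheck()
    key = UserAPIKeyAuth(api_key="sk-1234", allowed_cache_controls=["ttl"])

    # "ttl" is in allowed_cache_controls, so this returns without raising
    await check.async_pre_call_hook(
        user_api_key_dict=key,
        cache=DualCache(),
        data={"cache": {"ttl": 600}},
        call_type="completion",
    )

    # "no-store" is not allowed for this key, so the hook raises a 403
    try:
        await check.async_pre_call_hook(
            user_api_key_dict=key,
            cache=DualCache(),
            data={"cache": {"no-store": True}},
            call_type="completion",
        )
    except HTTPException as e:
        print(e.status_code, e.detail)

asyncio.run(demo())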

View file

@@ -1266,6 +1266,7 @@ async def generate_key_helper_fn(
    query_type: Literal["insert_data", "update_data"] = "insert_data",
    update_key_values: Optional[dict] = None,
    key_alias: Optional[str] = None,
    allowed_cache_controls: Optional[list] = [],
):
    global prisma_client, custom_db_client
@@ -1320,6 +1321,7 @@ async def generate_key_helper_fn(
    user_id = user_id or str(uuid.uuid4())
    tpm_limit = tpm_limit
    rpm_limit = rpm_limit
    allowed_cache_controls = allowed_cache_controls
    if type(team_id) is not str:
        team_id = str(team_id)
    try:
@@ -1336,6 +1338,7 @@ async def generate_key_helper_fn(
            "rpm_limit": rpm_limit,
            "budget_duration": budget_duration,
            "budget_reset_at": reset_at,
            "allowed_cache_controls": allowed_cache_controls,
        }
        key_data = {
            "token": token,
@@ -1354,6 +1357,7 @@ async def generate_key_helper_fn(
            "rpm_limit": rpm_limit,
            "budget_duration": key_budget_duration,
            "budget_reset_at": key_reset_at,
            "allowed_cache_controls": allowed_cache_controls,
        }
        if general_settings.get("allow_user_auth", False) == True:
            key_data["key_name"] = f"sk-...{token[-4:]}"
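
Since GenerateKeyRequest and generate_key_helper_fn now carry allowed_cache_controls end to end, the restriction can be set when a key is issued. A rough sketch against a locally running proxy (the base URL, master key, and model name are placeholders):

import requests

PROXY_BASE_URL = "http://0.0.0.0:8000"  # adjust to your deployment
MASTER_KEY = "sk-1234"                  # placeholder master key

# Issue a key that may only set the "ttl" cache control on its requests
resp = requests.post(
    f"{PROXY_BASE_URL}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={
        "models": ["gpt-3.5-turbo"],
        "allowed_cache_controls": ["ttl"],
    },
)
resp.raise_for_status()
print(resp.json()["key"])  # the newly generated virtual key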

View file

@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
  rpm_limit              BigInt?
  budget_duration        String?
  budget_reset_at        DateTime?
  allowed_cache_controls String[] @default([])
}
// Generate Tokens for Proxy
@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
  max_budget             Float?
  budget_duration        String?
  budget_reset_at        DateTime?
  allowed_cache_controls String[] @default([])
}
// store proxy config.yaml

View file

@@ -10,6 +10,7 @@ from litellm.proxy._types import (
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import CacheControlCheck
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
@@ -42,6 +43,7 @@ class ProxyLogging:
        self.call_details["user_api_key_cache"] = user_api_key_cache
        self.max_parallel_request_limiter = MaxParallelRequestsHandler()
        self.max_budget_limiter = MaxBudgetLimiter()
        self.cache_control_check = CacheControlCheck()
        self.alerting: Optional[List] = None
        self.alerting_threshold: float = 300  # default to 5 min. threshold
        pass
@@ -57,6 +59,7 @@ class ProxyLogging:
        print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
        litellm.callbacks.append(self.max_parallel_request_limiter)
        litellm.callbacks.append(self.max_budget_limiter)
        litellm.callbacks.append(self.cache_control_check)
        for callback in litellm.callbacks:
            if callback not in litellm.input_callback:
                litellm.input_callback.append(callback)

View file

@@ -2217,7 +2217,7 @@ def client(original_function):
                litellm.cache is not None
                and str(original_function.__name__)
                in litellm.cache.supported_call_types
            ):
            ) and (kwargs.get("cache", {}).get("no-store", False) != True):
                litellm.cache.add_cache(result, *args, **kwargs)
            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@@ -2430,9 +2430,12 @@ def client(original_function):
            # [OPTIONAL] ADD TO CACHE
            if (
                litellm.cache is not None
                and str(original_function.__name__)
                in litellm.cache.supported_call_types
                (litellm.cache is not None)
                and (
                    str(original_function.__name__)
                    in litellm.cache.supported_call_types
                )
                and (kwargs.get("cache", {}).get("no-store", False) != True)
            ):
                if isinstance(result, litellm.ModelResponse) or isinstance(
                    result, litellm.EmbeddingResponse
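
The other half of the change is the per-call no-store flag: when litellm.cache is configured, a request that sets cache={"no-store": True} still goes through the normal flow, but the branches above skip litellm.cache.add_cache(), so the response is never written back to the cache. A minimal client-side sketch, assuming provider credentials are set in the environment and using a placeholder model name:

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache for supported call types

# Cached as usual: the response is written via litellm.cache.add_cache()
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)

# "no-store": the caller still gets the response, but it is not added to the cache
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    cache={"no-store": True},
)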

View file

@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
  rpm_limit              BigInt?
  budget_duration        String?
  budget_reset_at        DateTime?
  allowed_cache_controls String[] @default([])
}
// Generate Tokens for Proxy
@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
  max_budget             Float?
  budget_duration        String?
  budget_reset_at        DateTime?
  allowed_cache_controls String[] @default([])
}
// store proxy config.yaml

View file

@@ -351,21 +351,10 @@ async def test_key_info_spend_values_sagemaker():
        prompt_tokens, completion_tokens = await chat_completion_streaming(
            session=session, key=new_key, model="sagemaker-completion-model"
        )
        # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
        # prompt_cost, completion_cost = litellm.cost_per_token(
        #     model="azure/gpt-35-turbo",
        #     prompt_tokens=prompt_tokens,
        #     completion_tokens=completion_tokens,
        # )
        # response_cost = prompt_cost + completion_cost
        await asyncio.sleep(5)  # allow db log to be updated
        key_info = await get_key_info(
            session=session, get_key=new_key, call_key=new_key
        )
        # print(
        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
        # )
        # rounded_response_cost = round(response_cost, 8)
        rounded_key_info_spend = round(key_info["info"]["spend"], 8)
        assert rounded_key_info_spend > 0
        # assert rounded_response_cost == rounded_key_info_spend