Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
feat(proxy_server.py): enable cache controls per key + no-store cache flag
This commit is contained in: parent 37de964da4, commit f9acad87dc
8 changed files with 108 additions and 42 deletions
@@ -5,6 +5,15 @@ from datetime import datetime
 import uuid, json, sys, os
+
+
+def hash_token(token: str):
+    import hashlib
+
+    # Hash the string using SHA-256
+    hashed_token = hashlib.sha256(token.encode()).hexdigest()
+
+    return hashed_token
 
 
 class LiteLLMBase(BaseModel):
     """
     Implements default functions, all pydantic objects should have.
@@ -137,6 +146,7 @@ class GenerateRequestBase(LiteLLMBase):
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None
     budget_duration: Optional[str] = None
+    allowed_cache_controls: Optional[list] = []
 
 
 class GenerateKeyRequest(GenerateRequestBase):
@@ -177,25 +187,6 @@ class UpdateKeyRequest(GenerateKeyRequest):
     metadata: Optional[dict] = None
 
 
-class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
-    """
-    Return the row in the db
-    """
-
-    api_key: Optional[str] = None
-    models: list = []
-    aliases: dict = {}
-    config: dict = {}
-    spend: Optional[float] = 0
-    max_budget: Optional[float] = None
-    user_id: Optional[str] = None
-    max_parallel_requests: Optional[int] = None
-    duration: str = "1h"
-    metadata: dict = {}
-    tpm_limit: Optional[int] = None
-    rpm_limit: Optional[int] = None
-
-
 class DeleteKeyRequest(LiteLLMBase):
     keys: List
 
@@ -320,22 +311,39 @@ class ConfigYAML(LiteLLMBase):
 
 
 class LiteLLM_VerificationToken(LiteLLMBase):
-    token: str
+    token: Optional[str] = None
     key_name: Optional[str] = None
     key_alias: Optional[str] = None
     spend: float = 0.0
     max_budget: Optional[float] = None
-    expires: Union[str, None]
-    models: List[str]
-    aliases: Dict[str, str] = {}
-    config: Dict[str, str] = {}
-    user_id: Union[str, None]
-    max_parallel_requests: Union[int, None]
-    metadata: Dict[str, str] = {}
+    expires: Optional[str] = None
+    models: List = []
+    aliases: Dict = {}
+    config: Dict = {}
+    user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None
+    metadata: Dict = {}
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None
     budget_duration: Optional[str] = None
     budget_reset_at: Optional[datetime] = None
+    allowed_cache_controls: Optional[list] = []
+
+
+class UserAPIKeyAuth(
+    LiteLLM_VerificationToken
+):  # the expected response object for user api key auth
+    """
+    Return the row in the db
+    """
+
+    api_key: Optional[str] = None
+
+    @root_validator(pre=True)
+    def check_api_key(cls, values):
+        if values.get("api_key") is not None:
+            values.update({"token": hash_token(values.get("api_key"))})
+        return values
 
 
 class LiteLLM_Config(LiteLLMBase):
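To illustrate the reshuffled models above, a minimal sketch (not part of the commit; it assumes hash_token and UserAPIKeyAuth are importable from litellm.proxy._types as shown in the diff): because UserAPIKeyAuth now inherits from LiteLLM_VerificationToken and carries a pre root_validator, building it from a raw api_key also fills in the hashed token field and exposes the new allowed_cache_controls list.

    from litellm.proxy._types import UserAPIKeyAuth, hash_token

    # allowed_cache_controls is inherited from LiteLLM_VerificationToken
    auth = UserAPIKeyAuth(api_key="sk-1234", allowed_cache_controls=["ttl"])

    # the pre root_validator hashes the raw key into the `token` field
    assert auth.token == hash_token("sk-1234")
    print(auth.allowed_cache_controls)  # ['ttl']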
litellm/proxy/hooks/cache_control_check.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+# What this does?
+## Checks if key is allowed to use the cache controls passed in to the completion() call
+
+from typing import Optional
+import litellm
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+import json, traceback
+
+
+class CacheControlCheck(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        try:
+            self.print_verbose(f"Inside Cache Control Check Pre-Call Hook")
+            allowed_cache_controls = user_api_key_dict.allowed_cache_controls
+
+            if (allowed_cache_controls is None) or (
+                len(allowed_cache_controls) == 0
+            ):  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
+                return
+
+            if data.get("cache", None) is None:
+                return
+
+            cache_args = data.get("cache", None)
+            if isinstance(cache_args, dict):
+                for k, v in cache_args.items():
+                    if k not in allowed_cache_controls:
+                        raise HTTPException(
+                            status_code=403,
+                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
+                        )
+            else:  # invalid cache
+                return
+
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            traceback.print_exc()
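A hedged usage sketch of the new hook (the setup below is illustrative, not taken from the commit): the hook only compares the request's cache dict against the key's allowed_cache_controls and raises a 403 HTTPException when a control is not permitted; an empty allow-list means no restriction.

    import asyncio
    from fastapi import HTTPException
    from litellm.caching import DualCache
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.hooks.cache_control_check import CacheControlCheck

    async def demo():
        check = CacheControlCheck()
        # this key may only set the "ttl" cache control
        key = UserAPIKeyAuth(api_key="sk-1234", allowed_cache_controls=["ttl"])

        # allowed control -> hook returns silently
        await check.async_pre_call_hook(
            user_api_key_dict=key,
            cache=DualCache(),
            data={"cache": {"ttl": 600}},
            call_type="completion",
        )

        # disallowed control -> 403 HTTPException
        try:
            await check.async_pre_call_hook(
                user_api_key_dict=key,
                cache=DualCache(),
                data={"cache": {"no-store": True}},
                call_type="completion",
            )
        except HTTPException as e:
            print(e.status_code, e.detail)

    asyncio.run(demo())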
@@ -1266,6 +1266,7 @@ async def generate_key_helper_fn(
     query_type: Literal["insert_data", "update_data"] = "insert_data",
     update_key_values: Optional[dict] = None,
     key_alias: Optional[str] = None,
+    allowed_cache_controls: Optional[list] = [],
 ):
     global prisma_client, custom_db_client
 
@@ -1320,6 +1321,7 @@
     user_id = user_id or str(uuid.uuid4())
     tpm_limit = tpm_limit
     rpm_limit = rpm_limit
+    allowed_cache_controls = allowed_cache_controls
     if type(team_id) is not str:
         team_id = str(team_id)
     try:
@@ -1336,6 +1338,7 @@
         "rpm_limit": rpm_limit,
         "budget_duration": budget_duration,
         "budget_reset_at": reset_at,
+        "allowed_cache_controls": allowed_cache_controls,
     }
     key_data = {
         "token": token,
@@ -1354,6 +1357,7 @@
         "rpm_limit": rpm_limit,
         "budget_duration": key_budget_duration,
         "budget_reset_at": key_reset_at,
+        "allowed_cache_controls": allowed_cache_controls,
     }
     if general_settings.get("allow_user_auth", False) == True:
         key_data["key_name"] = f"sk-...{token[-4:]}"
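With generate_key_helper_fn now accepting allowed_cache_controls, a proxy admin can scope cache permissions when minting a key. A hypothetical request sketch (the URL, port, and master key below are placeholders, and the key-generation endpoint path is an assumption):

    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/key/generate",           # local proxy; port is an assumption
        headers={"Authorization": "Bearer sk-1234"},  # proxy master key (placeholder)
        json={"allowed_cache_controls": ["ttl", "s-maxage"]},
    )
    print(resp.json())  # the new key's DB row stores allowed_cache_controls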
@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
   rpm_limit BigInt?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // Generate Tokens for Proxy
@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
   max_budget Float?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // store proxy config.yaml
@@ -10,6 +10,7 @@ from litellm.proxy._types import (
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
+from litellm.proxy.hooks.cache_control_check import CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy.db.base_client import CustomDB
 from litellm._logging import verbose_proxy_logger
@@ -42,6 +43,7 @@ class ProxyLogging:
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         self.max_budget_limiter = MaxBudgetLimiter()
+        self.cache_control_check = CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
         pass
@@ -57,6 +59,7 @@
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
         litellm.callbacks.append(self.max_budget_limiter)
+        litellm.callbacks.append(self.cache_control_check)
         for callback in litellm.callbacks:
             if callback not in litellm.input_callback:
                 litellm.input_callback.append(callback)
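A rough sketch of how a registered pre-call hook such as CacheControlCheck ends up vetting each request before it is forwarded (the function and its signature below are assumptions, not the proxy's actual dispatch code): every callback exposing async_pre_call_hook is awaited and may raise an HTTPException to reject the call.

    async def run_pre_call_hooks(user_api_key_dict, cache, data, call_type, callbacks):
        # callbacks would be litellm.callbacks, which now includes CacheControlCheck
        for callback in callbacks:
            hook = getattr(callback, "async_pre_call_hook", None)
            if hook is None:
                continue
            # a hook either returns (possibly modified) data or raises to block the call
            await hook(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
            )
        return data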
@@ -2217,7 +2217,7 @@ def client(original_function):
                 litellm.cache is not None
                 and str(original_function.__name__)
                 in litellm.cache.supported_call_types
-            ):
+            ) and (kwargs.get("cache", {}).get("no-store", False) != True):
                 litellm.cache.add_cache(result, *args, **kwargs)
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@@ -2430,9 +2430,12 @@
 
             # [OPTIONAL] ADD TO CACHE
             if (
-                litellm.cache is not None
-                and str(original_function.__name__)
-                in litellm.cache.supported_call_types
+                (litellm.cache is not None)
+                and (
+                    str(original_function.__name__)
+                    in litellm.cache.supported_call_types
+                )
+                and (kwargs.get("cache", {}).get("no-store", False) != True)
             ):
                 if isinstance(result, litellm.ModelResponse) or isinstance(
                     result, litellm.EmbeddingResponse
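The client() changes above mean a caller can opt a single request out of response caching. A minimal sketch (the model name and cache backend are placeholders; it assumes litellm.cache is configured, as required by the condition above): with cache={"no-store": True} the result is returned normally but is not written into the cache.

    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache()  # default in-memory cache

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        cache={"no-store": True},  # skip litellm.cache.add_cache for this response
    )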
The same two additions are made to the repository's second Prisma schema file, which is kept in sync with the one above:

@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
   rpm_limit BigInt?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // Generate Tokens for Proxy
@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
   max_budget Float?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // store proxy config.yaml
@@ -351,21 +351,10 @@ async def test_key_info_spend_values_sagemaker():
     prompt_tokens, completion_tokens = await chat_completion_streaming(
         session=session, key=new_key, model="sagemaker-completion-model"
     )
-    # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
-    # prompt_cost, completion_cost = litellm.cost_per_token(
-    #     model="azure/gpt-35-turbo",
-    #     prompt_tokens=prompt_tokens,
-    #     completion_tokens=completion_tokens,
-    # )
-    # response_cost = prompt_cost + completion_cost
     await asyncio.sleep(5)  # allow db log to be updated
     key_info = await get_key_info(
         session=session, get_key=new_key, call_key=new_key
     )
-    # print(
-    #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
-    # )
-    # rounded_response_cost = round(response_cost, 8)
     rounded_key_info_spend = round(key_info["info"]["spend"], 8)
     assert rounded_key_info_spend > 0
     # assert rounded_response_cost == rounded_key_info_spend