feat(proxy_server.py): enable cache controls per key + no-store cache flag
parent 37de964da4 · commit f9acad87dc
8 changed files with 108 additions and 42 deletions
litellm/proxy/_types.py

@@ -5,6 +5,15 @@ from datetime import datetime
 import uuid, json, sys, os


+def hash_token(token: str):
+    import hashlib
+
+    # Hash the string using SHA-256
+    hashed_token = hashlib.sha256(token.encode()).hexdigest()
+
+    return hashed_token
+
+
 class LiteLLMBase(BaseModel):
     """
     Implements default functions, all pydantic objects should have.
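Since SHA-256 is deterministic, the proxy can persist only the hash of a key and still match incoming bearer tokens. A minimal sketch of that lookup pattern (token_table is illustrative, not litellm code):

import hashlib

def hash_token(token: str):
    # Same helper as the hunk above: SHA-256 of the raw key
    return hashlib.sha256(token.encode()).hexdigest()

# Hypothetical store: only hashes are persisted, never raw keys
token_table = {hash_token("sk-my-test-key"): {"user_id": "u1"}}

incoming = "sk-my-test-key"
row = token_table.get(hash_token(incoming))  # deterministic hash -> direct lookup
assert row is not None and row["user_id"] == "u1"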
@@ -137,6 +146,7 @@ class GenerateRequestBase(LiteLLMBase):
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None
     budget_duration: Optional[str] = None
+    allowed_cache_controls: Optional[list] = []


 class GenerateKeyRequest(GenerateRequestBase):
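With allowed_cache_controls on GenerateRequestBase, a key can be minted with an explicit allow-list of cache-control fields. A sketch against the proxy's /key/generate endpoint; the host, master key, and chosen controls below are placeholders:

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",      # placeholder proxy address
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={
        "models": ["gpt-3.5-turbo"],
        "allowed_cache_controls": ["ttl", "s-maxage"],  # controls this key may send
    },
)
print(resp.json()["key"])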
@@ -177,25 +187,6 @@ class UpdateKeyRequest(GenerateKeyRequest):
     metadata: Optional[dict] = None


-class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
-    """
-    Return the row in the db
-    """
-
-    api_key: Optional[str] = None
-    models: list = []
-    aliases: dict = {}
-    config: dict = {}
-    spend: Optional[float] = 0
-    max_budget: Optional[float] = None
-    user_id: Optional[str] = None
-    max_parallel_requests: Optional[int] = None
-    duration: str = "1h"
-    metadata: dict = {}
-    tpm_limit: Optional[int] = None
-    rpm_limit: Optional[int] = None
-
-
 class DeleteKeyRequest(LiteLLMBase):
     keys: List
@@ -320,22 +311,39 @@ class ConfigYAML(LiteLLMBase):


 class LiteLLM_VerificationToken(LiteLLMBase):
-    token: str
+    token: Optional[str] = None
+    key_name: Optional[str] = None
+    key_alias: Optional[str] = None
     spend: float = 0.0
     max_budget: Optional[float] = None
-    expires: Union[str, None]
-    models: List[str]
-    aliases: Dict[str, str] = {}
-    config: Dict[str, str] = {}
-    user_id: Union[str, None]
-    max_parallel_requests: Union[int, None]
-    metadata: Dict[str, str] = {}
+    expires: Optional[str] = None
+    models: List = []
+    aliases: Dict = {}
+    config: Dict = {}
+    user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None
+    metadata: Dict = {}
+    tpm_limit: Optional[int] = None
+    rpm_limit: Optional[int] = None
+    budget_duration: Optional[str] = None
+    budget_reset_at: Optional[datetime] = None
+    allowed_cache_controls: Optional[list] = []
+
+
+class UserAPIKeyAuth(
+    LiteLLM_VerificationToken
+):  # the expected response object for user api key auth
+    """
+    Return the row in the db
+    """
+
+    api_key: Optional[str] = None
+
+    @root_validator(pre=True)
+    def check_api_key(cls, values):
+        if values.get("api_key") is not None:
+            values.update({"token": hash_token(values.get("api_key"))})
+        return values


 class LiteLLM_Config(LiteLLMBase):
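Because check_api_key runs with pre=True, the raw input dict is rewritten before field validation, so any api_key supplied is hashed into token automatically. A standalone sketch of the same pattern, using the pydantic v1-style root_validator this file relies on (AuthSketch is illustrative, not litellm code):

import hashlib
from typing import Optional
from pydantic import BaseModel, root_validator  # pydantic v1-style API

def hash_token(token: str):
    return hashlib.sha256(token.encode()).hexdigest()

class AuthSketch(BaseModel):
    token: Optional[str] = None
    api_key: Optional[str] = None

    @root_validator(pre=True)
    def check_api_key(cls, values):
        # runs before field validation, so it sees the raw input dict
        if values.get("api_key") is not None:
            values.update({"token": hash_token(values.get("api_key"))})
        return values

obj = AuthSketch(api_key="sk-my-test-key")
assert obj.token == hash_token("sk-my-test-key")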
litellm/proxy/hooks/cache_control_check.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+# What this does?
+## Checks if key is allowed to use the cache controls passed in to the completion() call
+
+from typing import Optional
+import litellm
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+import json, traceback
+
+
+class CacheControlCheck(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        try:
+            self.print_verbose(f"Inside Cache Control Check Pre-Call Hook")
+            allowed_cache_controls = user_api_key_dict.allowed_cache_controls
+
+            if (allowed_cache_controls is None) or (
+                len(allowed_cache_controls) == 0
+            ):  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
+                return
+
+            if data.get("cache", None) is None:
+                return
+
+            cache_args = data.get("cache", None)
+            if isinstance(cache_args, dict):
+                for k, v in cache_args.items():
+                    if k not in allowed_cache_controls:
+                        raise HTTPException(
+                            status_code=403,
+                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
+                        )
+            else:  # invalid cache
+                return
+
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            traceback.print_exc()
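Behavior worth noting: an empty allow-list means "no restrictions" (the linked Prisma issue explains why empty lists stand in for null), and only dict-valued cache arguments are inspected. A self-contained sketch of the same decision logic, with check_cache_controls as an illustrative standalone function rather than the hook itself:

from fastapi import HTTPException

def check_cache_controls(allowed_cache_controls, data: dict):
    # mirrors CacheControlCheck.async_pre_call_hook, minus proxy plumbing
    if not allowed_cache_controls:  # None or [] -> treated as "no restrictions"
        return
    cache_args = data.get("cache", None)
    if not isinstance(cache_args, dict):  # invalid cache value -> ignore
        return
    for k in cache_args:
        if k not in allowed_cache_controls:
            raise HTTPException(
                status_code=403,
                detail=f"Not allowed to set {k} as a cache control.",
            )

check_cache_controls(["ttl"], {"cache": {"ttl": 600}})  # allowed, returns None
try:
    check_cache_controls(["ttl"], {"cache": {"no-store": True}})  # rejected
except HTTPException as e:
    print(e.status_code, e.detail)  # 403 Not allowed to set no-store ...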
litellm/proxy/proxy_server.py

@@ -1266,6 +1266,7 @@ async def generate_key_helper_fn(
     query_type: Literal["insert_data", "update_data"] = "insert_data",
     update_key_values: Optional[dict] = None,
     key_alias: Optional[str] = None,
+    allowed_cache_controls: Optional[list] = [],
 ):
     global prisma_client, custom_db_client

@@ -1320,6 +1321,7 @@ async def generate_key_helper_fn(
     user_id = user_id or str(uuid.uuid4())
     tpm_limit = tpm_limit
     rpm_limit = rpm_limit
+    allowed_cache_controls = allowed_cache_controls
     if type(team_id) is not str:
         team_id = str(team_id)
     try:

@@ -1336,6 +1338,7 @@ async def generate_key_helper_fn(
         "rpm_limit": rpm_limit,
         "budget_duration": budget_duration,
         "budget_reset_at": reset_at,
+        "allowed_cache_controls": allowed_cache_controls,
     }
     key_data = {
         "token": token,

@@ -1354,6 +1357,7 @@ async def generate_key_helper_fn(
         "rpm_limit": rpm_limit,
         "budget_duration": key_budget_duration,
         "budget_reset_at": key_reset_at,
+        "allowed_cache_controls": allowed_cache_controls,
     }
     if general_settings.get("allow_user_auth", False) == True:
         key_data["key_name"] = f"sk-...{token[-4:]}"
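Since generate_key_helper_fn also services query_type="update_data", the allow-list on an existing key can be changed after the fact. A sketch against the proxy's /key/update endpoint, for which UpdateKeyRequest above is the request model; host, master key, and key value are placeholders:

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/update",
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={
        "key": "sk-my-generated-key",      # placeholder: a previously generated key
        "allowed_cache_controls": ["ttl"],  # tighten the allow-list
    },
)
print(resp.status_code)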
@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
   rpm_limit BigInt?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }

 // Generate Tokens for Proxy

@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
   max_budget Float?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }

 // store proxy config.yaml
litellm/proxy/utils.py

@@ -10,6 +10,7 @@ from litellm.proxy._types import (
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
+from litellm.proxy.hooks.cache_control_check import CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy.db.base_client import CustomDB
 from litellm._logging import verbose_proxy_logger

@@ -42,6 +43,7 @@ class ProxyLogging:
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         self.max_budget_limiter = MaxBudgetLimiter()
+        self.cache_control_check = CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
         pass

@@ -57,6 +59,7 @@ class ProxyLogging:
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
         litellm.callbacks.append(self.max_budget_limiter)
+        litellm.callbacks.append(self.cache_control_check)
         for callback in litellm.callbacks:
             if callback not in litellm.input_callback:
                 litellm.input_callback.append(callback)
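ProxyLogging registers each built-in check as a litellm callback, and at request time the proxy awaits every callback's async_pre_call_hook in turn. A compressed, self-contained sketch of that dispatch pattern (EchoHook and pre_call are illustrative names, not litellm code):

import asyncio

class EchoHook:
    # stand-in for a CustomLogger subclass such as CacheControlCheck
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        print(f"pre-call hook ran for call_type={call_type}")
        return data

async def pre_call(callbacks, user_api_key_dict, data, call_type):
    # mirrors the loop ProxyLogging runs over litellm.callbacks
    for cb in callbacks:
        hook = getattr(cb, "async_pre_call_hook", None)
        if hook is not None:
            result = await hook(user_api_key_dict, None, data, call_type)
            if result is not None:
                data = result  # hooks may rewrite the request payload
    return data

asyncio.run(pre_call([EchoHook()], None, {"model": "gpt-3.5-turbo"}, "completion"))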
litellm/utils.py

@@ -2217,7 +2217,7 @@ def client(original_function):
                 litellm.cache is not None
                 and str(original_function.__name__)
                 in litellm.cache.supported_call_types
-            ):
+            ) and (kwargs.get("cache", {}).get("no-store", False) != True):
                 litellm.cache.add_cache(result, *args, **kwargs)

             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated

@@ -2430,9 +2430,12 @@ def client(original_function):

             # [OPTIONAL] ADD TO CACHE
             if (
-                litellm.cache is not None
-                and str(original_function.__name__)
-                in litellm.cache.supported_call_types
+                (litellm.cache is not None)
+                and (
+                    str(original_function.__name__)
+                    in litellm.cache.supported_call_types
+                )
+                and (kwargs.get("cache", {}).get("no-store", False) != True)
             ):
                 if isinstance(result, litellm.ModelResponse) or isinstance(
                     result, litellm.EmbeddingResponse
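The net effect of the two hunks above: even with a global litellm.cache configured, a per-request cache dict with no-store set skips the cache write while the response is still returned. A minimal sketch, assuming an in-memory cache and a valid OPENAI_API_KEY in the environment:

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache, for the sketch only

# cached as usual
r1 = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)

# computed and returned, but never written to the cache
r2 = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    cache={"no-store": True},
)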
@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
   rpm_limit BigInt?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }

 // Generate Tokens for Proxy

@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
   max_budget Float?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }

 // store proxy config.yaml
@@ -351,21 +351,10 @@ async def test_key_info_spend_values_sagemaker():
         prompt_tokens, completion_tokens = await chat_completion_streaming(
             session=session, key=new_key, model="sagemaker-completion-model"
         )
         # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
-        # prompt_cost, completion_cost = litellm.cost_per_token(
-        #     model="azure/gpt-35-turbo",
-        #     prompt_tokens=prompt_tokens,
-        #     completion_tokens=completion_tokens,
-        # )
-        # response_cost = prompt_cost + completion_cost
         await asyncio.sleep(5)  # allow db log to be updated
         key_info = await get_key_info(
             session=session, get_key=new_key, call_key=new_key
         )
-        # print(
-        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
-        # )
-        # rounded_response_cost = round(response_cost, 8)
         rounded_key_info_spend = round(key_info["info"]["spend"], 8)
         assert rounded_key_info_spend > 0
-        # assert rounded_response_cost == rounded_key_info_spend