diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 22565eb2b7..4c88c0be97 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -5,6 +5,15 @@ from datetime import datetime
 import uuid, json, sys, os
 
 
+def hash_token(token: str):
+    import hashlib
+
+    # Hash the string using SHA-256
+    hashed_token = hashlib.sha256(token.encode()).hexdigest()
+
+    return hashed_token
+
+
 class LiteLLMBase(BaseModel):
     """
     Implements default functions, all pydantic objects should have.
@@ -137,6 +146,7 @@ class GenerateRequestBase(LiteLLMBase):
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None
     budget_duration: Optional[str] = None
+    allowed_cache_controls: Optional[list] = []
 
 
 class GenerateKeyRequest(GenerateRequestBase):
@@ -177,25 +187,6 @@ class UpdateKeyRequest(GenerateKeyRequest):
     metadata: Optional[dict] = None
 
 
-class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
-    """
-    Return the row in the db
-    """
-
-    api_key: Optional[str] = None
-    models: list = []
-    aliases: dict = {}
-    config: dict = {}
-    spend: Optional[float] = 0
-    max_budget: Optional[float] = None
-    user_id: Optional[str] = None
-    max_parallel_requests: Optional[int] = None
-    duration: str = "1h"
-    metadata: dict = {}
-    tpm_limit: Optional[int] = None
-    rpm_limit: Optional[int] = None
-
-
 class DeleteKeyRequest(LiteLLMBase):
     keys: List
 
@@ -320,22 +311,39 @@ class ConfigYAML(LiteLLMBase):
 
 
 class LiteLLM_VerificationToken(LiteLLMBase):
-    token: str
+    token: Optional[str] = None
     key_name: Optional[str] = None
     key_alias: Optional[str] = None
     spend: float = 0.0
     max_budget: Optional[float] = None
-    expires: Union[str, None]
-    models: List[str]
-    aliases: Dict[str, str] = {}
-    config: Dict[str, str] = {}
-    user_id: Union[str, None]
-    max_parallel_requests: Union[int, None]
-    metadata: Dict[str, str] = {}
+    expires: Optional[str] = None
+    models: List = []
+    aliases: Dict = {}
+    config: Dict = {}
+    user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None
+    metadata: Dict = {}
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None
     budget_duration: Optional[str] = None
     budget_reset_at: Optional[datetime] = None
+    allowed_cache_controls: Optional[list] = []
+
+
+class UserAPIKeyAuth(
+    LiteLLM_VerificationToken
+):  # the expected response object for user api key auth
+    """
+    Return the row in the db
+    """
+
+    api_key: Optional[str] = None
+
+    @root_validator(pre=True)
+    def check_api_key(cls, values):
+        if values.get("api_key") is not None:
+            values.update({"token": hash_token(values.get("api_key"))})
+        return values
 
 
 class LiteLLM_Config(LiteLLMBase):
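With the `_types.py` change, `UserAPIKeyAuth` inherits every budget/limit field from `LiteLLM_VerificationToken`, and the `pre=True` root validator derives `token` from a raw `api_key` via `hash_token`. A minimal sketch of that behavior, assuming the key value is just an illustrative placeholder:

```python
from litellm.proxy._types import UserAPIKeyAuth, hash_token

# Passing a raw api_key fires check_api_key(), which stores its SHA-256 hex digest in `token`.
auth = UserAPIKeyAuth(api_key="sk-1234", allowed_cache_controls=["no-cache"])

assert auth.token == hash_token("sk-1234")          # 64-char hex digest
assert auth.allowed_cache_controls == ["no-cache"]  # field inherited from LiteLLM_VerificationToken
```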
diff --git a/litellm/proxy/hooks/cache_control_check.py b/litellm/proxy/hooks/cache_control_check.py
new file mode 100644
index 0000000000..670e7554d6
--- /dev/null
+++ b/litellm/proxy/hooks/cache_control_check.py
@@ -0,0 +1,55 @@
+# What this does?
+## Checks if key is allowed to use the cache controls passed in to the completion() call
+
+from typing import Optional
+import litellm
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+import json, traceback
+
+
+class CacheControlCheck(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        try:
+            self.print_verbose(f"Inside Cache Control Check Pre-Call Hook")
+            allowed_cache_controls = user_api_key_dict.allowed_cache_controls
+
+            if (allowed_cache_controls is None) or (
+                len(allowed_cache_controls) == 0
+            ):  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
+                return
+
+            if data.get("cache", None) is None:
+                return
+
+            cache_args = data.get("cache", None)
+            if isinstance(cache_args, dict):
+                for k, v in cache_args.items():
+                    if k not in allowed_cache_controls:
+                        raise HTTPException(
+                            status_code=403,
+                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
+                        )
+            else:  # invalid cache
+                return
+
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            traceback.print_exc()
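A rough sketch of how the new hook is expected to behave when exercised directly; the key and request payloads below are invented for illustration:

```python
import asyncio

from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.cache_control_check import CacheControlCheck


async def demo():
    check = CacheControlCheck()
    # Hypothetical key that is only allowed to send the "no-cache" control.
    key = UserAPIKeyAuth(api_key="sk-1234", allowed_cache_controls=["no-cache"])

    # "no-cache" is on the key's allow-list, so this returns without raising.
    await check.async_pre_call_hook(
        user_api_key_dict=key,
        cache=DualCache(),
        data={"model": "gpt-3.5-turbo", "cache": {"no-cache": True}},
        call_type="completion",
    )

    # "no-store" is not on the allow-list, so the hook raises HTTPException(403).
    try:
        await check.async_pre_call_hook(
            user_api_key_dict=key,
            cache=DualCache(),
            data={"model": "gpt-3.5-turbo", "cache": {"no-store": True}},
            call_type="completion",
        )
    except HTTPException as e:
        print(e.status_code, e.detail)


asyncio.run(demo())
```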
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f1ec2744cd..6109b824d7 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1266,6 +1266,7 @@ async def generate_key_helper_fn(
     query_type: Literal["insert_data", "update_data"] = "insert_data",
     update_key_values: Optional[dict] = None,
     key_alias: Optional[str] = None,
+    allowed_cache_controls: Optional[list] = [],
 ):
     global prisma_client, custom_db_client
 
@@ -1320,6 +1321,7 @@ async def generate_key_helper_fn(
     user_id = user_id or str(uuid.uuid4())
     tpm_limit = tpm_limit
     rpm_limit = rpm_limit
+    allowed_cache_controls = allowed_cache_controls
     if type(team_id) is not str:
         team_id = str(team_id)
     try:
@@ -1336,6 +1338,7 @@ async def generate_key_helper_fn(
             "rpm_limit": rpm_limit,
             "budget_duration": budget_duration,
             "budget_reset_at": reset_at,
+            "allowed_cache_controls": allowed_cache_controls,
         }
         key_data = {
             "token": token,
@@ -1354,6 +1357,7 @@ async def generate_key_helper_fn(
             "rpm_limit": rpm_limit,
             "budget_duration": key_budget_duration,
             "budget_reset_at": key_reset_at,
+            "allowed_cache_controls": allowed_cache_controls,
         }
         if general_settings.get("allow_user_auth", False) == True:
             key_data["key_name"] = f"sk-...{token[-4:]}"
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index 02e4114e5d..da2857075f 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
   rpm_limit BigInt?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // Generate Tokens for Proxy
@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
   max_budget Float?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // store proxy config.yaml
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 3ec45203f5..9983150d97 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -10,6 +10,7 @@ from litellm.proxy._types import (
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
+from litellm.proxy.hooks.cache_control_check import CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy.db.base_client import CustomDB
 from litellm._logging import verbose_proxy_logger
@@ -42,6 +43,7 @@ class ProxyLogging:
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         self.max_budget_limiter = MaxBudgetLimiter()
+        self.cache_control_check = CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
         pass
@@ -57,6 +59,7 @@ class ProxyLogging:
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
         litellm.callbacks.append(self.max_budget_limiter)
+        litellm.callbacks.append(self.cache_control_check)
         for callback in litellm.callbacks:
             if callback not in litellm.input_callback:
                 litellm.input_callback.append(callback)
diff --git a/litellm/utils.py b/litellm/utils.py
index 3aaf535149..d91d262c5b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2217,7 +2217,7 @@ def client(original_function):
                     litellm.cache is not None
                     and str(original_function.__name__)
                     in litellm.cache.supported_call_types
-                ):
+                ) and (kwargs.get("cache", {}).get("no-store", False) != True):
                     litellm.cache.add_cache(result, *args, **kwargs)
 
                 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@@ -2430,9 +2430,12 @@ def client(original_function):
 
                 # [OPTIONAL] ADD TO CACHE
                 if (
-                    litellm.cache is not None
-                    and str(original_function.__name__)
-                    in litellm.cache.supported_call_types
+                    (litellm.cache is not None)
+                    and (
+                        str(original_function.__name__)
+                        in litellm.cache.supported_call_types
+                    )
+                    and (kwargs.get("cache", {}).get("no-store", False) != True)
                 ):
                     if isinstance(result, litellm.ModelResponse) or isinstance(
                         result, litellm.EmbeddingResponse
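The `litellm/utils.py` change means a per-request `cache={"no-store": True}` lets the call run normally but skips the `add_cache()` write-back. A hedged usage sketch, assuming valid provider credentials are configured; the model and message are placeholders:

```python
import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # default in-memory cache

# The response is produced as usual, but with "no-store" it is never written
# back to litellm.cache, so later identical calls will not be served from cache.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    cache={"no-store": True},
)
```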
diff --git a/schema.prisma b/schema.prisma
index 02e4114e5d..da2857075f 100644
--- a/schema.prisma
+++ b/schema.prisma
@@ -20,6 +20,7 @@ model LiteLLM_UserTable {
   rpm_limit BigInt?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // Generate Tokens for Proxy
@@ -41,6 +42,7 @@ model LiteLLM_VerificationToken {
   max_budget Float?
   budget_duration String?
   budget_reset_at DateTime?
+  allowed_cache_controls String[] @default([])
 }
 
 // store proxy config.yaml
diff --git a/tests/test_keys.py b/tests/test_keys.py
index d4ab826d40..97a309e305 100644
--- a/tests/test_keys.py
+++ b/tests/test_keys.py
@@ -351,21 +351,10 @@ async def test_key_info_spend_values_sagemaker():
         prompt_tokens, completion_tokens = await chat_completion_streaming(
            session=session, key=new_key, model="sagemaker-completion-model"
         )
-        # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
-        # prompt_cost, completion_cost = litellm.cost_per_token(
-        #     model="azure/gpt-35-turbo",
-        #     prompt_tokens=prompt_tokens,
-        #     completion_tokens=completion_tokens,
-        # )
-        # response_cost = prompt_cost + completion_cost
         await asyncio.sleep(5)  # allow db log to be updated
         key_info = await get_key_info(
             session=session, get_key=new_key, call_key=new_key
         )
-        # print(
-        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
-        # )
-        # rounded_response_cost = round(response_cost, 8)
         rounded_key_info_spend = round(key_info["info"]["spend"], 8)
         assert rounded_key_info_spend > 0
         # assert rounded_response_cost == rounded_key_info_spend
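End to end, the allow-list is set when the key is issued, since `allowed_cache_controls` is now part of `GenerateKeyRequest`. A sketch against a locally running proxy; the proxy URL, master key, and chosen controls are illustrative:

```python
import requests

# Issue a key that may only send the "no-cache" and "no-store" cache controls.
resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "models": ["gpt-3.5-turbo"],
        "allowed_cache_controls": ["no-cache", "no-store"],
    },
)
print(resp.json()["key"])
```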