From cd8d35107bd2ff10eeefa9c061a11ca93cc3cc18 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 15 Feb 2024 20:16:15 -0800
Subject: [PATCH] fix: check key permissions for turning on/off pii masking

---
 litellm/proxy/_types.py                     |  2 +-
 litellm/proxy/hooks/cache_control_check.py  | 14 ++++++++------
 litellm/proxy/hooks/presidio_pii_masking.py | 22 ++++++++++++++++------
 schema.prisma                               |  2 +-
 4 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index e2f3e696f6..3f8f1944ed 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -382,7 +382,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
     budget_duration: Optional[str] = None
     budget_reset_at: Optional[datetime] = None
     allowed_cache_controls: Optional[list] = []
-    permissions: Optional[dict] = None
+    permissions: Dict = {}
 
 
 class UserAPIKeyAuth(
diff --git a/litellm/proxy/hooks/cache_control_check.py b/litellm/proxy/hooks/cache_control_check.py
index c50c4ec1fc..3160fe97ad 100644
--- a/litellm/proxy/hooks/cache_control_check.py
+++ b/litellm/proxy/hooks/cache_control_check.py
@@ -30,18 +30,20 @@ class _PROXY_CacheControlCheck(CustomLogger):
             self.print_verbose(f"Inside Cache Control Check Pre-Call Hook")
             allowed_cache_controls = user_api_key_dict.allowed_cache_controls
 
-            if (allowed_cache_controls is None) or (
-                len(allowed_cache_controls) == 0
-            ):  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
-                return
-
             if data.get("cache", None) is None:
                 return
 
             cache_args = data.get("cache", None)
             if isinstance(cache_args, dict):
                 for k, v in cache_args.items():
-                    if k not in allowed_cache_controls:
+                    if (
+                        (allowed_cache_controls is not None)
+                        and (isinstance(allowed_cache_controls, list))
+                        and (
+                            len(allowed_cache_controls) > 0
+                        )  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
+                        and k not in allowed_cache_controls
+                    ):
                         raise HTTPException(
                             status_code=403,
                             detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py
index 85e6260745..5152046bc5 100644
--- a/litellm/proxy/hooks/presidio_pii_masking.py
+++ b/litellm/proxy/hooks/presidio_pii_masking.py
@@ -61,7 +61,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
         except:
             pass
 
-    async def check_pii(self, text: str) -> str:
+    async def check_pii(self, text: str, output_parse_pii: bool) -> str:
         """
         [TODO] make this more performant for high-throughput scenario
         """
@@ -92,10 +92,7 @@
                         start = item["start"]
                         end = item["end"]
                         replacement = item["text"]  # replacement token
-                        if (
-                            item["operator"] == "replace"
-                            and litellm.output_parse_pii == True
-                        ):
+                        if item["operator"] == "replace" and output_parse_pii == True:
                             # check if token in dict
                             # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
                             if replacement in self.pii_tokens:
@@ -125,13 +122,26 @@
 
         For multiple messages in /chat/completions, we'll need to call them in parallel.
""" + permissions = user_api_key_dict.permissions + + if permissions.get("pii", True) == False: # allow key to turn off pii masking + return data + + output_parse_pii = permissions.get( + "output_parse_pii", litellm.output_parse_pii + ) # allow key to turn on/off output parsing for pii + if call_type == "completion": # /chat/completions requests messages = data["messages"] tasks = [] for m in messages: if isinstance(m["content"], str): - tasks.append(self.check_pii(text=m["content"])) + tasks.append( + self.check_pii( + text=m["content"], output_parse_pii=output_parse_pii + ) + ) responses = await asyncio.gather(*tasks) for index, r in enumerate(responses): if isinstance(messages[index]["content"], str): diff --git a/schema.prisma b/schema.prisma index dd473bb69e..a047951dcc 100644 --- a/schema.prisma +++ b/schema.prisma @@ -55,6 +55,7 @@ model LiteLLM_VerificationToken { config Json @default("{}") user_id String? team_id String? + permissions Json @default("{}") max_parallel_requests Int? metadata Json @default("{}") tpm_limit BigInt? @@ -63,7 +64,6 @@ model LiteLLM_VerificationToken { budget_duration String? budget_reset_at DateTime? allowed_cache_controls String[] @default([]) - permissions Json? } // store proxy config.yaml