fix: check key permissions for turning on/off pii masking

This commit is contained in:
Krrish Dholakia 2024-02-15 20:16:15 -08:00
parent cccd577e75
commit cd8d35107b
4 changed files with 26 additions and 14 deletions

View file

@@ -382,7 +382,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
allowed_cache_controls: Optional[list] = []
permissions: Optional[dict] = None
permissions: Dict = {}
class UserAPIKeyAuth(

View file

@@ -30,18 +30,20 @@ class _PROXY_CacheControlCheck(CustomLogger):
self.print_verbose(f"Inside Cache Control Check Pre-Call Hook")
allowed_cache_controls = user_api_key_dict.allowed_cache_controls
if (allowed_cache_controls is None) or (
len(allowed_cache_controls) == 0
): # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
return
if data.get("cache", None) is None:
return
cache_args = data.get("cache", None)
if isinstance(cache_args, dict):
for k, v in cache_args.items():
if k not in allowed_cache_controls:
if (
(allowed_cache_controls is not None)
and (isinstance(allowed_cache_controls, list))
and (
len(allowed_cache_controls) > 0
) # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
and k not in allowed_cache_controls
):
raise HTTPException(
status_code=403,
detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",

View file

@@ -61,7 +61,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
except:
pass
async def check_pii(self, text: str) -> str:
async def check_pii(self, text: str, output_parse_pii: bool) -> str:
"""
[TODO] make this more performant for high-throughput scenario
"""
@@ -92,10 +92,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
start = item["start"]
end = item["end"]
replacement = item["text"] # replacement token
if (
item["operator"] == "replace"
and litellm.output_parse_pii == True
):
if item["operator"] == "replace" and output_parse_pii == True:
# check if token in dict
# if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
if replacement in self.pii_tokens:
@@ -125,13 +122,26 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
For multiple messages in /chat/completions, we'll need to call them in parallel.
"""
permissions = user_api_key_dict.permissions
if permissions.get("pii", True) == False: # allow key to turn off pii masking
return data
output_parse_pii = permissions.get(
"output_parse_pii", litellm.output_parse_pii
) # allow key to turn on/off output parsing for pii
if call_type == "completion": # /chat/completions requests
messages = data["messages"]
tasks = []
for m in messages:
if isinstance(m["content"], str):
tasks.append(self.check_pii(text=m["content"]))
tasks.append(
self.check_pii(
text=m["content"], output_parse_pii=output_parse_pii
)
)
responses = await asyncio.gather(*tasks)
for index, r in enumerate(responses):
if isinstance(messages[index]["content"], str):

View file

@@ -55,6 +55,7 @@ model LiteLLM_VerificationToken {
config Json @default("{}")
user_id String?
team_id String?
permissions Json @default("{}")
max_parallel_requests Int?
metadata Json @default("{}")
tpm_limit BigInt?
@@ -63,7 +64,6 @@ model LiteLLM_VerificationToken {
budget_duration String?
budget_reset_at DateTime?
allowed_cache_controls String[] @default([])
permissions Json?
}
// store proxy config.yaml