feat(llm_guard.py): enable key-specific llm guard check

This commit is contained in:
Krrish Dholakia 2024-03-26 17:21:51 -07:00
parent bec093675c
commit e10eb8f6fe
8 changed files with 26 additions and 3 deletions

View file

@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""

View file

@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""

View file

@ -30,9 +30,12 @@ litellm.set_verbose = True
class _ENTERPRISE_LLMGuard(CustomLogger):
# Class variables or attributes
def __init__(
self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
self,
mock_testing: bool = False,
mock_redacted_text: Optional[dict] = None,
):
self.mock_redacted_text = mock_redacted_text
self.llm_guard_mode = litellm.llm_guard_mode
if mock_testing == True: # for testing purposes only
return
self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
@ -95,6 +98,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""
@ -104,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
- LLM Guard can handle things like PII Masking, etc.
"""
self.print_verbose(f"Inside LLM Guard Pre-Call Hook")
# check if llm guard enabled for specific keys only
if self.llm_guard_mode == "key-specific":
if (
user_api_key_dict.permissions.get("enable_llm_guard_check", False)
== False
):
return
try:
assert call_type in [
"completion",