feat(llama_guard.py): add llama guard support for content moderation + new async_moderation_hook endpoint

Krrish Dholakia 2024-02-16 18:45:25 -08:00
parent 5e7dda4f88
commit 2a4a6995ac
12 changed files with 163 additions and 132 deletions
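For reference, the new guard is enabled through the proxy's litellm_settings callbacks list (callbacks: ["llamaguard_moderations"]), and custom moderation hooks follow the same shape. Below is a minimal sketch of such a hook, assuming async_moderation_hook receives the request data dict; the real signature on litellm's CustomLogger base class may take additional arguments.

    # Hypothetical hook for illustration -- the exact async_moderation_hook
    # signature is an assumption, not the confirmed litellm interface.
    from fastapi import HTTPException
    from litellm.integrations.custom_logger import CustomLogger

    class _MyModerationHook(CustomLogger):
        async def async_moderation_hook(self, data: dict):
            # Reject the request while the LLM call is in flight.
            for message in data.get("messages", []):
                if "forbidden phrase" in str(message.get("content", "")).lower():
                    raise HTTPException(
                        status_code=400, detail={"error": "Violated content policy"}
                    )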


@@ -1368,7 +1368,7 @@ class ProxyConfig:
                     )
                 elif key == "callbacks":
                     if isinstance(value, list):
-                        imported_list = []
+                        imported_list: List[Any] = []
                         for callback in value:  # ["presidio", <my-custom-callback>]
                             if isinstance(callback, str) and callback == "presidio":
                                 from litellm.proxy.hooks.presidio_pii_masking import (
@@ -1377,6 +1377,16 @@ class ProxyConfig:
                                 pii_masking_object = _OPTIONAL_PresidioPIIMasking()
                                 imported_list.append(pii_masking_object)
+                            elif (
+                                isinstance(callback, str)
+                                and callback == "llamaguard_moderations"
+                            ):
+                                from litellm.proxy.enterprise.hooks.llama_guard import (
+                                    _ENTERPRISE_LlamaGuard,
+                                )
+                                llama_guard_object = _ENTERPRISE_LlamaGuard()
+                                imported_list.append(llama_guard_object)
                             else:
                                 imported_list.append(
                                     get_instance_fn(
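The hunk above resolves well-known callback strings into hook instances before appending them to imported_list; unknown entries fall through to dynamic loading via get_instance_fn. A standalone sketch of that name-to-instance dispatch follows (the classes here are illustrative stand-ins, not the real litellm hooks).

    from typing import Any, Callable, Dict, List

    class PiiMaskingHookStub: ...
    class LlamaGuardHookStub: ...

    # Map recognized callback names to factories for their hook objects.
    _KNOWN_CALLBACKS: Dict[str, Callable[[], Any]] = {
        "presidio": PiiMaskingHookStub,
        "llamaguard_moderations": LlamaGuardHookStub,
    }

    def resolve_callbacks(names: List[str]) -> List[Any]:
        imported_list: List[Any] = []
        for callback in names:
            factory = _KNOWN_CALLBACKS.get(callback)
            # Unknown strings would be imported dynamically in the real code.
            imported_list.append(factory() if factory is not None else callback)
        return imported_list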
@@ -2423,6 +2433,9 @@ async def chat_completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
+        tasks = []
+        tasks.append(proxy_logging_obj.during_call_hook(data=data))
         start_time = time.time()
         ### ROUTE THE REQUEST ###
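tasks collects the moderation coroutine before the request is routed, so it can be awaited together with the LLM call below. Presumably during_call_hook fans the request out to every registered hook that implements async_moderation_hook; here is a rough sketch under that assumption, not the actual ProxyLogging implementation.

    import asyncio
    from typing import Any, List

    class ProxyLoggingSketch:
        # Hypothetical stand-in for the proxy's logging/hook dispatcher.
        def __init__(self, callbacks: List[Any]):
            self.callbacks = callbacks

        async def during_call_hook(self, data: dict):
            # Run every moderation-capable callback against the request payload.
            await asyncio.gather(
                *(
                    cb.async_moderation_hook(data=data)
                    for cb in self.callbacks
                    if hasattr(cb, "async_moderation_hook")
                )
            )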
@@ -2433,34 +2446,40 @@
         )
         # skip router if user passed their key
         if "api_key" in data:
-            response = await litellm.acompletion(**data)
+            tasks.append(litellm.acompletion(**data))
         elif "user_config" in data:
             # initialize a new router instance. make request using this Router
             router_config = data.pop("user_config")
             user_router = litellm.Router(**router_config)
-            response = await user_router.acompletion(**data)
+            tasks.append(user_router.acompletion(**data))
         elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
-            response = await llm_router.acompletion(**data)
+            tasks.append(llm_router.acompletion(**data))
         elif (
             llm_router is not None
             and llm_router.model_group_alias is not None
             and data["model"] in llm_router.model_group_alias
         ):  # model set in model_group_alias
-            response = await llm_router.acompletion(**data)
+            tasks.append(llm_router.acompletion(**data))
         elif (
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.acompletion(**data, specific_deployment=True)
+            tasks.append(llm_router.acompletion(**data, specific_deployment=True))
         elif user_model is not None:  # `litellm --model <your-model-name>`
-            response = await litellm.acompletion(**data)
+            tasks.append(litellm.acompletion(**data))
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail={"error": "Invalid model name passed in"},
             )
+        # wait for call to end
+        responses = await asyncio.gather(
+            *tasks
+        )  # run the moderation check in parallel to the actual llm api call
+        response = responses[1]
         # Post Call Processing
         data["litellm_status"] = "success"  # used for alerting
         if hasattr(response, "_hidden_params"):
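Each routing branch now appends its acompletion coroutine to tasks instead of awaiting it, so asyncio.gather runs the moderation check and the actual LLM call concurrently; responses[1] is the completion because the moderation task was appended first. A self-contained sketch of the pattern with dummy coroutines:

    import asyncio

    async def moderation_check(data: dict) -> None:
        # Stand-in for proxy_logging_obj.during_call_hook; raising blocks the request.
        if "bad" in str(data.get("messages", "")):
            raise ValueError("content policy violation")

    async def call_llm(data: dict) -> dict:
        # Stand-in for litellm.acompletion / router.acompletion.
        await asyncio.sleep(0.1)
        return {"choices": [{"message": {"content": "ok"}}]}

    async def chat_completion(data: dict) -> dict:
        tasks = [moderation_check(data), call_llm(data)]
        # Moderation runs in parallel with the llm api call; index 1 is the llm response.
        responses = await asyncio.gather(*tasks)
        return responses[1]

    print(asyncio.run(chat_completion({"messages": [{"role": "user", "content": "hi"}]})))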