Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
feat(llama_guard.py): add llama guard support for content moderation + new async_moderation_hook endpoint
Parent: 5e7dda4f88
Commit: 2a4a6995ac
12 changed files with 163 additions and 132 deletions

@@ -1368,7 +1368,7 @@ class ProxyConfig:
             )
         elif key == "callbacks":
             if isinstance(value, list):
-                imported_list = []
+                imported_list: List[Any] = []
                 for callback in value:  # ["presidio", <my-custom-callback>]
                     if isinstance(callback, str) and callback == "presidio":
                         from litellm.proxy.hooks.presidio_pii_masking import (
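
For orientation, the loop in this hunk walks the value of a "callbacks" entry in the proxy configuration, where each item is either a known string such as "presidio" or a custom callback reference. A minimal sketch of that input shape, assuming the config has already been loaded (typically from the proxy's YAML file) into a dict; the variable name is illustrative, only the "callbacks" key and its list-of-strings value come from the hunk above:

# Illustrative only: the surrounding layout is an assumption; the "callbacks"
# key and its list of callback names mirror what the loop above iterates over.
config_litellm_settings = {
    "callbacks": ["presidio", "llamaguard_moderations"],
}
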
@@ -1377,6 +1377,16 @@ class ProxyConfig:

                         pii_masking_object = _OPTIONAL_PresidioPIIMasking()
                         imported_list.append(pii_masking_object)
+                    elif (
+                        isinstance(callback, str)
+                        and callback == "llamaguard_moderations"
+                    ):
+                        from litellm.proxy.enterprise.hooks.llama_guard import (
+                            _ENTERPRISE_LlamaGuard,
+                        )
+
+                        llama_guard_object = _ENTERPRISE_LlamaGuard()
+                        imported_list.append(llama_guard_object)
                     else:
                         imported_list.append(
                             get_instance_fn(
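
The new branch above maps the config string "llamaguard_moderations" to an _ENTERPRISE_LlamaGuard instance. The enterprise implementation itself is not part of this diff; the following is a minimal sketch of the general shape such a content-moderation hook could take, assuming (per the commit title) it exposes an async_moderation_hook method that raises to block a flagged request. The class name, method signature, and helper are illustrative, not the actual litellm API.

from fastapi import HTTPException


class IllustrativeLlamaGuardHook:
    """Sketch only: stands in for _ENTERPRISE_LlamaGuard, whose code is not in this diff."""

    async def async_moderation_hook(self, data: dict):
        # Assumption: the hook receives the incoming request body ("data") and
        # raises an HTTPException to reject it when the guard model flags it.
        messages = data.get("messages", [])
        last_turn = messages[-1].get("content", "") if messages else ""
        if await self._is_flagged(last_turn):
            raise HTTPException(
                status_code=400,
                detail={"error": "Violated content safety policy"},
            )

    async def _is_flagged(self, text: str) -> bool:
        # Placeholder for a call to a LlamaGuard-style classifier; always passes here.
        return False
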
@@ -2423,6 +2433,9 @@ async def chat_completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )

+        tasks = []
+        tasks.append(proxy_logging_obj.during_call_hook(data=data))
+
         start_time = time.time()

         ### ROUTE THE REQUEST ###
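
Note that tasks.append(proxy_logging_obj.during_call_hook(data=data)) only creates a coroutine object; nothing runs at this point. The hook is actually executed by the asyncio.gather added in the next hunk, which is what lets the moderation check run concurrently with the completion call instead of delaying it.
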
@@ -2433,34 +2446,40 @@ async def chat_completion(
         )
         # skip router if user passed their key
         if "api_key" in data:
-            response = await litellm.acompletion(**data)
+            tasks.append(litellm.acompletion(**data))
         elif "user_config" in data:
             # initialize a new router instance. make request using this Router
             router_config = data.pop("user_config")
             user_router = litellm.Router(**router_config)
-            response = await user_router.acompletion(**data)
+            tasks.append(user_router.acompletion(**data))
         elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
-            response = await llm_router.acompletion(**data)
+            tasks.append(llm_router.acompletion(**data))
         elif (
             llm_router is not None
             and llm_router.model_group_alias is not None
             and data["model"] in llm_router.model_group_alias
         ):  # model set in model_group_alias
-            response = await llm_router.acompletion(**data)
+            tasks.append(llm_router.acompletion(**data))
         elif (
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.acompletion(**data, specific_deployment=True)
+            tasks.append(llm_router.acompletion(**data, specific_deployment=True))
         elif user_model is not None:  # `litellm --model <your-model-name>`
-            response = await litellm.acompletion(**data)
+            tasks.append(litellm.acompletion(**data))
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail={"error": "Invalid model name passed in"},
             )

+        # wait for call to end
+        responses = await asyncio.gather(
+            *tasks
+        )  # run the moderation check in parallel to the actual llm api call
+        response = responses[1]
+
         # Post Call Processing
         data["litellm_status"] = "success"  # used for alerting
         if hasattr(response, "_hidden_params"):
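
Because the during-call hook was appended to tasks first and the completion coroutine second, asyncio.gather returns their results in that order, so responses[1] is the model response. A self-contained sketch of the same pattern with stand-in coroutines (names and sleep durations are invented for illustration):

import asyncio
import time


async def moderation_hook(data: dict) -> None:
    await asyncio.sleep(0.5)  # stand-in for a LlamaGuard classification call


async def completion_call(data: dict) -> str:
    await asyncio.sleep(2.0)  # stand-in for the upstream LLM request
    return "model response"


async def handle_request(data: dict) -> str:
    start = time.time()
    tasks = [moderation_hook(data), completion_call(data)]  # same order as the diff
    responses = await asyncio.gather(*tasks)  # both run concurrently
    print(f"elapsed ~{time.time() - start:.1f}s")  # ~2.0s total, not 2.5s
    return responses[1]  # index 1 is the completion result


print(asyncio.run(handle_request({"messages": []})))

One consequence of this design: with gather's default behavior, if the moderation hook raises (for example, to block flagged content), the exception propagates to the caller and the request fails even though the completion call was already in flight.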