diff --git a/enterprise/enterprise_hooks/lakera_ai.py b/enterprise/enterprise_hooks/lakera_ai.py
index 2dc77d65a..72485d589 100644
--- a/enterprise/enterprise_hooks/lakera_ai.py
+++ b/enterprise/enterprise_hooks/lakera_ai.py
@@ -10,7 +10,7 @@ import sys, os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-from typing import Literal
+from typing import Literal, List, Dict
 import litellm, sys
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
@@ -18,7 +18,7 @@ from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
 
-from litellm.types.guardrails import Role
+from litellm.types.guardrails import Role, GuardrailItem, default_roles
 
 from litellm._logging import verbose_proxy_logger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@@ -33,9 +33,10 @@ GUARDRAIL_NAME = "lakera_prompt_injection"
 INPUT_POSITIONING_MAP = {
     Role.SYSTEM.value: 0,
     Role.USER.value: 1,
-    Role.ASSISTANT.value: 2
+    Role.ASSISTANT.value: 2,
 }
 
+
 class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
     def __init__(self):
         self.async_handler = AsyncHTTPHandler(
@@ -63,41 +64,63 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
             return
         text = ""
         if "messages" in data and isinstance(data["messages"], list):
-            enabled_roles = litellm.guardrail_name_config_map["prompt_injection"].enabled_roles
-            lakera_input_dict = {role: None for role in INPUT_POSITIONING_MAP.keys()}
-            system_message = None
-            tool_call_messages = []
+            enabled_roles = litellm.guardrail_name_config_map[
+                "prompt_injection"
+            ].enabled_roles
+            if enabled_roles is None:
+                enabled_roles = default_roles
+            lakera_input_dict: Dict = {
+                role: None for role in INPUT_POSITIONING_MAP.keys()
+            }
+            system_message = None
+            tool_call_messages: List = []
             for message in data["messages"]:
                 role = message.get("role")
                 if role in enabled_roles:
                     if "tool_calls" in message:
-                        tool_call_messages = [*tool_call_messages, *message["tool_calls"]]
-                    if role == Role.SYSTEM.value: # we need this for later
+                        tool_call_messages = [
+                            *tool_call_messages,
+                            *message["tool_calls"],
+                        ]
+                    if role == Role.SYSTEM.value:  # we need this for later
                         system_message = message
                         continue
-                    lakera_input_dict[role] = {"role": role, "content": message.get('content')}
+                    lakera_input_dict[role] = {
+                        "role": role,
+                        "content": message.get("content"),
+                    }
 
-            # For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here. 
+            # For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here.
             # Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
             # Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
             # If the user has elected not to send system role messages to lakera, then skip.
             if system_message is not None:
                 if not litellm.add_function_to_prompt:
                     content = system_message.get("content")
-                    function_input = []
+                    function_input = []
                     for tool_call in tool_call_messages:
                         if "function" in tool_call:
                             function_input.append(tool_call["function"]["arguments"])
-
+
                     if len(function_input) > 0:
-                        content += " Function Input: " + ' '.join(function_input)
-                lakera_input_dict[Role.SYSTEM.value] = {'role': Role.SYSTEM.value, 'content': content}
+                        content += " Function Input: " + " ".join(function_input)
+                lakera_input_dict[Role.SYSTEM.value] = {
+                    "role": Role.SYSTEM.value,
+                    "content": content,
+                }
 
-
-            lakera_input = [v for k, v in sorted(lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]) if v is not None]
+            lakera_input = [
+                v
+                for k, v in sorted(
+                    lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]
+                )
+                if v is not None
+            ]
             if len(lakera_input) == 0:
-                verbose_proxy_logger.debug("Skipping lakera prompt injection, no roles with messages found")
+                verbose_proxy_logger.debug(
+                    "Skipping lakera prompt injection, no roles with messages found"
+                )
                 return
 
         elif "input" in data and isinstance(data["input"], str):
diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py
index aa6205c7c..5713fa782 100644
--- a/litellm/proxy/health_check.py
+++ b/litellm/proxy/health_check.py
@@ -1,14 +1,13 @@
 # This file runs a health check for the LLM, used on litellm/proxy
 import asyncio
+import logging
 import random
 from typing import Optional
 
 import litellm
-import logging
 from litellm._logging import print_verbose
 
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -16,6 +15,7 @@ ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]
 
 MINIMAL_DISPLAY_PARAMS = ["model"]
 
+
 def _get_random_llm_message():
     """
     Get a random message from the LLM.
@@ -25,7 +25,7 @@ def _get_random_llm_message():
     return [{"role": "user", "content": random.choice(messages)}]
 
 
-def _clean_endpoint_data(endpoint_data: dict, details: bool):
+def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
     """
     Clean the endpoint data for display to users.
     """
@@ -36,7 +36,7 @@ def _clean_endpoint_data(endpoint_data: dict, details: bool):
     )
 
 
-async def _perform_health_check(model_list: list, details: bool):
+async def _perform_health_check(model_list: list, details: Optional[bool] = True):
     """
     Perform a health check for each model in the list.
     """
@@ -64,9 +64,13 @@ async def _perform_health_check(model_list: list, details: bool):
         litellm_params = model["litellm_params"]
 
         if isinstance(is_healthy, dict) and "error" not in is_healthy:
-            healthy_endpoints.append(_clean_endpoint_data({**litellm_params, **is_healthy}, details))
+            healthy_endpoints.append(
+                _clean_endpoint_data({**litellm_params, **is_healthy}, details)
+            )
         elif isinstance(is_healthy, dict):
-            unhealthy_endpoints.append(_clean_endpoint_data({**litellm_params, **is_healthy}, details))
+            unhealthy_endpoints.append(
+                _clean_endpoint_data({**litellm_params, **is_healthy}, details)
+            )
         else:
             unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
 
@@ -74,7 +78,10 @@ async def _perform_health_check(model_list: list, details: bool):
 async def perform_health_check(
-    model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None, details: Optional[bool] = True
+    model_list: list,
+    model: Optional[str] = None,
+    cli_model: Optional[str] = None,
+    details: Optional[bool] = True,
 ):
     """
     Perform a health check on the system.
@@ -98,6 +105,8 @@ async def perform_health_check(
         _new_model_list = [x for x in model_list if x["model_name"] == model]
         model_list = _new_model_list
 
-    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list, details)
+    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
+        model_list, details
+    )
 
     return healthy_endpoints, unhealthy_endpoints
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index 3b6dfba9f..27be12615 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -18,12 +18,15 @@ litellm_settings:
         default_on: true
 """
 
+
 class Role(Enum):
     SYSTEM = "system"
     ASSISTANT = "assistant"
     USER = "user"
 
-default_roles = [Role.SYSTEM, Role.ASSISTANT, Role.USER];
+
+default_roles = [Role.SYSTEM, Role.ASSISTANT, Role.USER]
+
 
 class GuardrailItemSpec(TypedDict, total=False):
     callbacks: Required[List[str]]
@@ -37,7 +40,7 @@ class GuardrailItem(BaseModel):
     default_on: bool
     logging_only: Optional[bool]
     guardrail_name: str
-    enabled_roles: List[Role]
+    enabled_roles: Optional[List[Role]]
     model_config = ConfigDict(use_enum_values=True)
 
     def __init__(
@@ -46,7 +49,7 @@ class GuardrailItem(BaseModel):
         guardrail_name: str,
         default_on: bool = False,
         logging_only: Optional[bool] = None,
-        enabled_roles: List[Role] = default_roles,
+        enabled_roles: Optional[List[Role]] = default_roles,
     ):
         super().__init__(
            callbacks=callbacks,
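Not part of the diff — a minimal sketch of how the loosened `enabled_roles` typing is expected to behave, assuming the `GuardrailItem` model shown above (pydantic v2 with `use_enum_values=True`, so roles should be stored as their string values):

```python
from litellm.types.guardrails import GuardrailItem, default_roles

# Omitting enabled_roles keeps the previous behaviour: it falls back to default_roles.
item = GuardrailItem(
    callbacks=["lakera_prompt_injection"],
    guardrail_name="prompt_injection",
)
print(item.enabled_roles)  # expected: ["system", "assistant", "user"]

# enabled_roles may now also be None; the Lakera hook above guards against that
# case and substitutes default_roles itself.
item_none = GuardrailItem(
    callbacks=["lakera_prompt_injection"],
    guardrail_name="prompt_injection",
    enabled_roles=None,
)
enabled_roles = item_none.enabled_roles
if enabled_roles is None:
    enabled_roles = default_roles
```

The last three lines mirror the `if enabled_roles is None: enabled_roles = default_roles` fallback added in lakera_ai.py.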
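Likewise, a hedged sketch of the health_check.py change: `details` is now `Optional[bool]` defaulting to `True` on `_clean_endpoint_data`, `_perform_health_check`, and `perform_health_check`, so callers may omit it. The model list below is illustrative and the call issues a real health-check request:

```python
import asyncio

from litellm.proxy.health_check import perform_health_check

# Illustrative proxy-style model list; the api_key is a placeholder.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
    }
]


async def main():
    # details now defaults to True, so these two calls behave the same.
    healthy, unhealthy = await perform_health_check(model_list)
    healthy, unhealthy = await perform_health_check(model_list, details=True)
    print(len(healthy), len(unhealthy))


asyncio.run(main())
```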