fix linting errors on main

This commit is contained in:
Ishaan Jaff 2024-07-18 13:32:48 -07:00
parent df4aab8be9
commit 75ca53fab5
3 changed files with 64 additions and 29 deletions

View file

@@ -10,7 +10,7 @@ import sys, os
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from typing import Literal from typing import Literal, List, Dict
import litellm, sys import litellm, sys
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@@ -18,7 +18,7 @@ from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import Role from litellm.types.guardrails import Role, GuardrailItem, default_roles
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@@ -33,9 +33,10 @@ GUARDRAIL_NAME = "lakera_prompt_injection"
INPUT_POSITIONING_MAP = { INPUT_POSITIONING_MAP = {
Role.SYSTEM.value: 0, Role.SYSTEM.value: 0,
Role.USER.value: 1, Role.USER.value: 1,
Role.ASSISTANT.value: 2 Role.ASSISTANT.value: 2,
} }
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger): class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
def __init__(self): def __init__(self):
self.async_handler = AsyncHTTPHandler( self.async_handler = AsyncHTTPHandler(
@@ -63,20 +64,32 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
return return
text = "" text = ""
if "messages" in data and isinstance(data["messages"], list): if "messages" in data and isinstance(data["messages"], list):
enabled_roles = litellm.guardrail_name_config_map["prompt_injection"].enabled_roles enabled_roles = litellm.guardrail_name_config_map[
lakera_input_dict = {role: None for role in INPUT_POSITIONING_MAP.keys()} "prompt_injection"
].enabled_roles
if enabled_roles is None:
enabled_roles = default_roles
lakera_input_dict: Dict = {
role: None for role in INPUT_POSITIONING_MAP.keys()
}
system_message = None system_message = None
tool_call_messages = [] tool_call_messages: List = []
for message in data["messages"]: for message in data["messages"]:
role = message.get("role") role = message.get("role")
if role in enabled_roles: if role in enabled_roles:
if "tool_calls" in message: if "tool_calls" in message:
tool_call_messages = [*tool_call_messages, *message["tool_calls"]] tool_call_messages = [
*tool_call_messages,
*message["tool_calls"],
]
if role == Role.SYSTEM.value: # we need this for later if role == Role.SYSTEM.value: # we need this for later
system_message = message system_message = message
continue continue
lakera_input_dict[role] = {"role": role, "content": message.get('content')} lakera_input_dict[role] = {
"role": role,
"content": message.get("content"),
}
# For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here. # For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here.
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already) # Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
@@ -91,13 +104,23 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
function_input.append(tool_call["function"]["arguments"]) function_input.append(tool_call["function"]["arguments"])
if len(function_input) > 0: if len(function_input) > 0:
content += " Function Input: " + ' '.join(function_input) content += " Function Input: " + " ".join(function_input)
lakera_input_dict[Role.SYSTEM.value] = {'role': Role.SYSTEM.value, 'content': content} lakera_input_dict[Role.SYSTEM.value] = {
"role": Role.SYSTEM.value,
"content": content,
}
lakera_input = [
lakera_input = [v for k, v in sorted(lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]) if v is not None] v
for k, v in sorted(
lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]
)
if v is not None
]
if len(lakera_input) == 0: if len(lakera_input) == 0:
verbose_proxy_logger.debug("Skipping lakera prompt injection, no roles with messages found") verbose_proxy_logger.debug(
"Skipping lakera prompt injection, no roles with messages found"
)
return return
elif "input" in data and isinstance(data["input"], str): elif "input" in data and isinstance(data["input"], str):

View file

@@ -1,14 +1,13 @@
# This file runs a health check for the LLM, used on litellm/proxy # This file runs a health check for the LLM, used on litellm/proxy
import asyncio import asyncio
import logging
import random import random
from typing import Optional from typing import Optional
import litellm import litellm
import logging
from litellm._logging import print_verbose from litellm._logging import print_verbose
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,6 +15,7 @@ ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]
MINIMAL_DISPLAY_PARAMS = ["model"] MINIMAL_DISPLAY_PARAMS = ["model"]
def _get_random_llm_message(): def _get_random_llm_message():
""" """
Get a random message from the LLM. Get a random message from the LLM.
@@ -25,7 +25,7 @@ def _get_random_llm_message():
return [{"role": "user", "content": random.choice(messages)}] return [{"role": "user", "content": random.choice(messages)}]
def _clean_endpoint_data(endpoint_data: dict, details: bool): def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
""" """
Clean the endpoint data for display to users. Clean the endpoint data for display to users.
""" """
@@ -36,7 +36,7 @@ def _clean_endpoint_data(endpoint_data: dict, details: bool):
) )
async def _perform_health_check(model_list: list, details: bool): async def _perform_health_check(model_list: list, details: Optional[bool] = True):
""" """
Perform a health check for each model in the list. Perform a health check for each model in the list.
""" """
@@ -64,9 +64,13 @@ async def _perform_health_check(model_list: list, details: bool):
litellm_params = model["litellm_params"] litellm_params = model["litellm_params"]
if isinstance(is_healthy, dict) and "error" not in is_healthy: if isinstance(is_healthy, dict) and "error" not in is_healthy:
healthy_endpoints.append(_clean_endpoint_data({**litellm_params, **is_healthy}, details)) healthy_endpoints.append(
_clean_endpoint_data({**litellm_params, **is_healthy}, details)
)
elif isinstance(is_healthy, dict): elif isinstance(is_healthy, dict):
unhealthy_endpoints.append(_clean_endpoint_data({**litellm_params, **is_healthy}, details)) unhealthy_endpoints.append(
_clean_endpoint_data({**litellm_params, **is_healthy}, details)
)
else: else:
unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details)) unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
@@ -74,7 +78,10 @@ async def _perform_health_check(model_list: list, details: bool):
async def perform_health_check( async def perform_health_check(
model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None, details: Optional[bool] = True model_list: list,
model: Optional[str] = None,
cli_model: Optional[str] = None,
details: Optional[bool] = True,
): ):
""" """
Perform a health check on the system. Perform a health check on the system.
@@ -98,6 +105,8 @@ async def perform_health_check(
_new_model_list = [x for x in model_list if x["model_name"] == model] _new_model_list = [x for x in model_list if x["model_name"] == model]
model_list = _new_model_list model_list = _new_model_list
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list, details) healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
model_list, details
)
return healthy_endpoints, unhealthy_endpoints return healthy_endpoints, unhealthy_endpoints

View file

@@ -18,12 +18,15 @@ litellm_settings:
default_on: true default_on: true
""" """
class Role(Enum): class Role(Enum):
SYSTEM = "system" SYSTEM = "system"
ASSISTANT = "assistant" ASSISTANT = "assistant"
USER = "user" USER = "user"
default_roles = [Role.SYSTEM, Role.ASSISTANT, Role.USER];
default_roles = [Role.SYSTEM, Role.ASSISTANT, Role.USER]
class GuardrailItemSpec(TypedDict, total=False): class GuardrailItemSpec(TypedDict, total=False):
callbacks: Required[List[str]] callbacks: Required[List[str]]
@@ -37,7 +40,7 @@ class GuardrailItem(BaseModel):
default_on: bool default_on: bool
logging_only: Optional[bool] logging_only: Optional[bool]
guardrail_name: str guardrail_name: str
enabled_roles: List[Role] enabled_roles: Optional[List[Role]]
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
def __init__( def __init__(
@@ -46,7 +49,7 @@ class GuardrailItem(BaseModel):
guardrail_name: str, guardrail_name: str,
default_on: bool = False, default_on: bool = False,
logging_only: Optional[bool] = None, logging_only: Optional[bool] = None,
enabled_roles: List[Role] = default_roles, enabled_roles: Optional[List[Role]] = default_roles,
): ):
super().__init__( super().__init__(
callbacks=callbacks, callbacks=callbacks,