forked from phoenix/litellm-mirror
fix linting errors on main
This commit is contained in:
parent
df4aab8be9
commit
75ca53fab5
3 changed files with 64 additions and 29 deletions
|
@ -10,7 +10,7 @@ import sys, os
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
from typing import Literal
|
from typing import Literal, List, Dict
|
||||||
import litellm, sys
|
import litellm, sys
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@ -18,7 +18,7 @@ from fastapi import HTTPException
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
||||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
from litellm.types.guardrails import Role
|
from litellm.types.guardrails import Role, GuardrailItem, default_roles
|
||||||
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
@ -33,9 +33,10 @@ GUARDRAIL_NAME = "lakera_prompt_injection"
|
||||||
INPUT_POSITIONING_MAP = {
|
INPUT_POSITIONING_MAP = {
|
||||||
Role.SYSTEM.value: 0,
|
Role.SYSTEM.value: 0,
|
||||||
Role.USER.value: 1,
|
Role.USER.value: 1,
|
||||||
Role.ASSISTANT.value: 2
|
Role.ASSISTANT.value: 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.async_handler = AsyncHTTPHandler(
|
self.async_handler = AsyncHTTPHandler(
|
||||||
|
@ -63,41 +64,63 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
||||||
return
|
return
|
||||||
text = ""
|
text = ""
|
||||||
if "messages" in data and isinstance(data["messages"], list):
|
if "messages" in data and isinstance(data["messages"], list):
|
||||||
enabled_roles = litellm.guardrail_name_config_map["prompt_injection"].enabled_roles
|
enabled_roles = litellm.guardrail_name_config_map[
|
||||||
lakera_input_dict = {role: None for role in INPUT_POSITIONING_MAP.keys()}
|
"prompt_injection"
|
||||||
system_message = None
|
].enabled_roles
|
||||||
tool_call_messages = []
|
if enabled_roles is None:
|
||||||
|
enabled_roles = default_roles
|
||||||
|
lakera_input_dict: Dict = {
|
||||||
|
role: None for role in INPUT_POSITIONING_MAP.keys()
|
||||||
|
}
|
||||||
|
system_message = None
|
||||||
|
tool_call_messages: List = []
|
||||||
for message in data["messages"]:
|
for message in data["messages"]:
|
||||||
role = message.get("role")
|
role = message.get("role")
|
||||||
if role in enabled_roles:
|
if role in enabled_roles:
|
||||||
if "tool_calls" in message:
|
if "tool_calls" in message:
|
||||||
tool_call_messages = [*tool_call_messages, *message["tool_calls"]]
|
tool_call_messages = [
|
||||||
if role == Role.SYSTEM.value: # we need this for later
|
*tool_call_messages,
|
||||||
|
*message["tool_calls"],
|
||||||
|
]
|
||||||
|
if role == Role.SYSTEM.value: # we need this for later
|
||||||
system_message = message
|
system_message = message
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lakera_input_dict[role] = {"role": role, "content": message.get('content')}
|
lakera_input_dict[role] = {
|
||||||
|
"role": role,
|
||||||
|
"content": message.get("content"),
|
||||||
|
}
|
||||||
|
|
||||||
# For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here.
|
# For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here.
|
||||||
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
|
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
|
||||||
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
|
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
|
||||||
# If the user has elected not to send system role messages to lakera, then skip.
|
# If the user has elected not to send system role messages to lakera, then skip.
|
||||||
if system_message is not None:
|
if system_message is not None:
|
||||||
if not litellm.add_function_to_prompt:
|
if not litellm.add_function_to_prompt:
|
||||||
content = system_message.get("content")
|
content = system_message.get("content")
|
||||||
function_input = []
|
function_input = []
|
||||||
for tool_call in tool_call_messages:
|
for tool_call in tool_call_messages:
|
||||||
if "function" in tool_call:
|
if "function" in tool_call:
|
||||||
function_input.append(tool_call["function"]["arguments"])
|
function_input.append(tool_call["function"]["arguments"])
|
||||||
|
|
||||||
if len(function_input) > 0:
|
if len(function_input) > 0:
|
||||||
content += " Function Input: " + ' '.join(function_input)
|
content += " Function Input: " + " ".join(function_input)
|
||||||
lakera_input_dict[Role.SYSTEM.value] = {'role': Role.SYSTEM.value, 'content': content}
|
lakera_input_dict[Role.SYSTEM.value] = {
|
||||||
|
"role": Role.SYSTEM.value,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
lakera_input = [
|
||||||
lakera_input = [v for k, v in sorted(lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]) if v is not None]
|
v
|
||||||
|
for k, v in sorted(
|
||||||
|
lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]
|
||||||
|
)
|
||||||
|
if v is not None
|
||||||
|
]
|
||||||
if len(lakera_input) == 0:
|
if len(lakera_input) == 0:
|
||||||
verbose_proxy_logger.debug("Skipping lakera prompt injection, no roles with messages found")
|
verbose_proxy_logger.debug(
|
||||||
|
"Skipping lakera prompt injection, no roles with messages found"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
elif "input" in data and isinstance(data["input"], str):
|
elif "input" in data and isinstance(data["input"], str):
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
# This file runs a health check for the LLM, used on litellm/proxy
|
# This file runs a health check for the LLM, used on litellm/proxy
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
import logging
|
|
||||||
from litellm._logging import print_verbose
|
from litellm._logging import print_verbose
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,6 +15,7 @@ ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]
|
||||||
|
|
||||||
MINIMAL_DISPLAY_PARAMS = ["model"]
|
MINIMAL_DISPLAY_PARAMS = ["model"]
|
||||||
|
|
||||||
|
|
||||||
def _get_random_llm_message():
|
def _get_random_llm_message():
|
||||||
"""
|
"""
|
||||||
Get a random message from the LLM.
|
Get a random message from the LLM.
|
||||||
|
@ -25,7 +25,7 @@ def _get_random_llm_message():
|
||||||
return [{"role": "user", "content": random.choice(messages)}]
|
return [{"role": "user", "content": random.choice(messages)}]
|
||||||
|
|
||||||
|
|
||||||
def _clean_endpoint_data(endpoint_data: dict, details: bool):
|
def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
|
||||||
"""
|
"""
|
||||||
Clean the endpoint data for display to users.
|
Clean the endpoint data for display to users.
|
||||||
"""
|
"""
|
||||||
|
@ -36,7 +36,7 @@ def _clean_endpoint_data(endpoint_data: dict, details: bool):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _perform_health_check(model_list: list, details: bool):
|
async def _perform_health_check(model_list: list, details: Optional[bool] = True):
|
||||||
"""
|
"""
|
||||||
Perform a health check for each model in the list.
|
Perform a health check for each model in the list.
|
||||||
"""
|
"""
|
||||||
|
@ -64,9 +64,13 @@ async def _perform_health_check(model_list: list, details: bool):
|
||||||
litellm_params = model["litellm_params"]
|
litellm_params = model["litellm_params"]
|
||||||
|
|
||||||
if isinstance(is_healthy, dict) and "error" not in is_healthy:
|
if isinstance(is_healthy, dict) and "error" not in is_healthy:
|
||||||
healthy_endpoints.append(_clean_endpoint_data({**litellm_params, **is_healthy}, details))
|
healthy_endpoints.append(
|
||||||
|
_clean_endpoint_data({**litellm_params, **is_healthy}, details)
|
||||||
|
)
|
||||||
elif isinstance(is_healthy, dict):
|
elif isinstance(is_healthy, dict):
|
||||||
unhealthy_endpoints.append(_clean_endpoint_data({**litellm_params, **is_healthy}, details))
|
unhealthy_endpoints.append(
|
||||||
|
_clean_endpoint_data({**litellm_params, **is_healthy}, details)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
|
unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
|
||||||
|
|
||||||
|
@ -74,7 +78,10 @@ async def _perform_health_check(model_list: list, details: bool):
|
||||||
|
|
||||||
|
|
||||||
async def perform_health_check(
|
async def perform_health_check(
|
||||||
model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None, details: Optional[bool] = True
|
model_list: list,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
cli_model: Optional[str] = None,
|
||||||
|
details: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a health check on the system.
|
Perform a health check on the system.
|
||||||
|
@ -98,6 +105,8 @@ async def perform_health_check(
|
||||||
_new_model_list = [x for x in model_list if x["model_name"] == model]
|
_new_model_list = [x for x in model_list if x["model_name"] == model]
|
||||||
model_list = _new_model_list
|
model_list = _new_model_list
|
||||||
|
|
||||||
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list, details)
|
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
|
||||||
|
model_list, details
|
||||||
|
)
|
||||||
|
|
||||||
return healthy_endpoints, unhealthy_endpoints
|
return healthy_endpoints, unhealthy_endpoints
|
||||||
|
|
|
@ -18,12 +18,15 @@ litellm_settings:
|
||||||
default_on: true
|
default_on: true
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Role(Enum):
|
class Role(Enum):
|
||||||
SYSTEM = "system"
|
SYSTEM = "system"
|
||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
|
||||||
default_roles = [Role.SYSTEM, Role.ASSISTANT, Role.USER];
|
|
||||||
|
default_roles = [Role.SYSTEM, Role.ASSISTANT, Role.USER]
|
||||||
|
|
||||||
|
|
||||||
class GuardrailItemSpec(TypedDict, total=False):
|
class GuardrailItemSpec(TypedDict, total=False):
|
||||||
callbacks: Required[List[str]]
|
callbacks: Required[List[str]]
|
||||||
|
@ -37,7 +40,7 @@ class GuardrailItem(BaseModel):
|
||||||
default_on: bool
|
default_on: bool
|
||||||
logging_only: Optional[bool]
|
logging_only: Optional[bool]
|
||||||
guardrail_name: str
|
guardrail_name: str
|
||||||
enabled_roles: List[Role]
|
enabled_roles: Optional[List[Role]]
|
||||||
model_config = ConfigDict(use_enum_values=True)
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -46,7 +49,7 @@ class GuardrailItem(BaseModel):
|
||||||
guardrail_name: str,
|
guardrail_name: str,
|
||||||
default_on: bool = False,
|
default_on: bool = False,
|
||||||
logging_only: Optional[bool] = None,
|
logging_only: Optional[bool] = None,
|
||||||
enabled_roles: List[Role] = default_roles,
|
enabled_roles: Optional[List[Role]] = default_roles,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue