forked from phoenix/litellm-mirror
refactor a bit
This commit is contained in:
parent
6ff863ee00
commit
b83f47e941
2 changed files with 8 additions and 25 deletions
|
@ -1,21 +1,12 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## This tests the Lakera AI integration
|
## This tests the Lakera AI integration
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import litellm.llms
|
|
||||||
import litellm.llms.custom_httpx
|
|
||||||
import litellm.llms.custom_httpx.http_handler
|
|
||||||
import litellm.llms.custom_httpx.httpx_handler
|
|
||||||
from litellm.types.guardrails import GuardrailItem
|
from litellm.types.guardrails import GuardrailItem
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -29,15 +20,14 @@ import logging
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import Router, mock_completion
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||||
_ENTERPRISE_lakeraAI_Moderation,
|
_ENTERPRISE_lakeraAI_Moderation,
|
||||||
)
|
)
|
||||||
from litellm.proxy.utils import ProxyLogging, hash_token
|
from litellm.proxy.utils import hash_token
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch
|
||||||
|
|
||||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
@ -59,7 +49,6 @@ async def test_lakera_prompt_injection_detection():
|
||||||
_api_key = "sk-12345"
|
_api_key = "sk-12345"
|
||||||
_api_key = hash_token("sk-12345")
|
_api_key = hash_token("sk-12345")
|
||||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||||
local_cache = DualCache()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await lakera_ai.async_moderation_hook(
|
await lakera_ai.async_moderation_hook(
|
||||||
|
@ -94,7 +83,7 @@ async def test_lakera_safe_prompt():
|
||||||
_api_key = "sk-12345"
|
_api_key = "sk-12345"
|
||||||
_api_key = hash_token("sk-12345")
|
_api_key = hash_token("sk-12345")
|
||||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||||
local_cache = DualCache()
|
|
||||||
await lakera_ai.async_moderation_hook(
|
await lakera_ai.async_moderation_hook(
|
||||||
data={
|
data={
|
||||||
"messages": [
|
"messages": [
|
||||||
|
|
|
@ -4157,7 +4157,11 @@ def get_formatted_prompt(
|
||||||
for c in content:
|
for c in content:
|
||||||
if c["type"] == "text":
|
if c["type"] == "text":
|
||||||
prompt += c["text"]
|
prompt += c["text"]
|
||||||
prompt += get_tool_call_function_args(message)
|
if "tool_calls" in message:
|
||||||
|
for tool_call in message["tool_calls"]:
|
||||||
|
if "function" in tool_call:
|
||||||
|
function_arguments = tool_call["function"]["arguments"]
|
||||||
|
prompt += function_arguments
|
||||||
elif call_type == "text_completion":
|
elif call_type == "text_completion":
|
||||||
prompt = data["prompt"]
|
prompt = data["prompt"]
|
||||||
elif call_type == "embedding" or call_type == "moderation":
|
elif call_type == "embedding" or call_type == "moderation":
|
||||||
|
@ -4173,16 +4177,6 @@ def get_formatted_prompt(
|
||||||
prompt = data["prompt"]
|
prompt = data["prompt"]
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def get_tool_call_function_args(message: dict) -> str:
|
|
||||||
all_args = ""
|
|
||||||
if "tool_calls" in message:
|
|
||||||
for tool_call in message["tool_calls"]:
|
|
||||||
if "function" in tool_call:
|
|
||||||
all_args += tool_call["function"]["arguments"]
|
|
||||||
|
|
||||||
return all_args
|
|
||||||
|
|
||||||
|
|
||||||
def get_response_string(response_obj: ModelResponse) -> str:
|
def get_response_string(response_obj: ModelResponse) -> str:
|
||||||
_choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
|
_choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue