forked from phoenix/litellm-mirror
refactor a bit
This commit is contained in:
parent
6ff863ee00
commit
b83f47e941
2 changed files with 8 additions and 25 deletions
|
@ -1,21 +1,12 @@
|
|||
# What is this?
|
||||
## This tests the Lakera AI integration
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import HTTPException
|
||||
import litellm.llms
|
||||
import litellm.llms.custom_httpx
|
||||
import litellm.llms.custom_httpx.http_handler
|
||||
import litellm.llms.custom_httpx.httpx_handler
|
||||
from litellm.types.guardrails import GuardrailItem
|
||||
|
||||
load_dotenv()
|
||||
|
@ -29,15 +20,14 @@ import logging
|
|||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import Router, mock_completion
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||
_ENTERPRISE_lakeraAI_Moderation,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging, hash_token
|
||||
from unittest.mock import patch, MagicMock
|
||||
from litellm.proxy.utils import hash_token
|
||||
from unittest.mock import patch
|
||||
|
||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
@ -59,7 +49,6 @@ async def test_lakera_prompt_injection_detection():
|
|||
_api_key = "sk-12345"
|
||||
_api_key = hash_token("sk-12345")
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||
local_cache = DualCache()
|
||||
|
||||
try:
|
||||
await lakera_ai.async_moderation_hook(
|
||||
|
@ -94,7 +83,7 @@ async def test_lakera_safe_prompt():
|
|||
_api_key = "sk-12345"
|
||||
_api_key = hash_token("sk-12345")
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||
local_cache = DualCache()
|
||||
|
||||
await lakera_ai.async_moderation_hook(
|
||||
data={
|
||||
"messages": [
|
||||
|
|
|
@ -4157,7 +4157,11 @@ def get_formatted_prompt(
|
|||
for c in content:
|
||||
if c["type"] == "text":
|
||||
prompt += c["text"]
|
||||
prompt += get_tool_call_function_args(message)
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
if "function" in tool_call:
|
||||
function_arguments = tool_call["function"]["arguments"]
|
||||
prompt += function_arguments
|
||||
elif call_type == "text_completion":
|
||||
prompt = data["prompt"]
|
||||
elif call_type == "embedding" or call_type == "moderation":
|
||||
|
@ -4173,16 +4177,6 @@ def get_formatted_prompt(
|
|||
prompt = data["prompt"]
|
||||
return prompt
|
||||
|
||||
def get_tool_call_function_args(message: dict) -> str:
|
||||
all_args = ""
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
if "function" in tool_call:
|
||||
all_args += tool_call["function"]["arguments"]
|
||||
|
||||
return all_args
|
||||
|
||||
|
||||
def get_response_string(response_obj: ModelResponse) -> str:
|
||||
_choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue