refactor a bit

This commit is contained in:
Vinnie Giarrusso 2024-07-16 12:19:31 -07:00
parent 6ff863ee00
commit b83f47e941
2 changed files with 8 additions and 25 deletions

View file

@ -1,21 +1,12 @@
# What is this? # What is this?
## This tests the Lakera AI integration ## This tests the Lakera AI integration
import asyncio
import os import os
import random
import sys import sys
import time
import traceback
import json import json
from datetime import datetime
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import HTTPException from fastapi import HTTPException
import litellm.llms
import litellm.llms.custom_httpx
import litellm.llms.custom_httpx.http_handler
import litellm.llms.custom_httpx.httpx_handler
from litellm.types.guardrails import GuardrailItem from litellm.types.guardrails import GuardrailItem
load_dotenv() load_dotenv()
@ -29,15 +20,14 @@ import logging
import pytest import pytest
import litellm import litellm
from litellm import Router, mock_completion
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import ( from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation, _ENTERPRISE_lakeraAI_Moderation,
) )
from litellm.proxy.utils import ProxyLogging, hash_token from litellm.proxy.utils import hash_token
from unittest.mock import patch, MagicMock from unittest.mock import patch
verbose_proxy_logger.setLevel(logging.DEBUG) verbose_proxy_logger.setLevel(logging.DEBUG)
@ -59,7 +49,6 @@ async def test_lakera_prompt_injection_detection():
_api_key = "sk-12345" _api_key = "sk-12345"
_api_key = hash_token("sk-12345") _api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key) user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()
try: try:
await lakera_ai.async_moderation_hook( await lakera_ai.async_moderation_hook(
@ -94,7 +83,7 @@ async def test_lakera_safe_prompt():
_api_key = "sk-12345" _api_key = "sk-12345"
_api_key = hash_token("sk-12345") _api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key) user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()
await lakera_ai.async_moderation_hook( await lakera_ai.async_moderation_hook(
data={ data={
"messages": [ "messages": [

View file

@ -4157,7 +4157,11 @@ def get_formatted_prompt(
for c in content: for c in content:
if c["type"] == "text": if c["type"] == "text":
prompt += c["text"] prompt += c["text"]
prompt += get_tool_call_function_args(message) if "tool_calls" in message:
for tool_call in message["tool_calls"]:
if "function" in tool_call:
function_arguments = tool_call["function"]["arguments"]
prompt += function_arguments
elif call_type == "text_completion": elif call_type == "text_completion":
prompt = data["prompt"] prompt = data["prompt"]
elif call_type == "embedding" or call_type == "moderation": elif call_type == "embedding" or call_type == "moderation":
@ -4173,16 +4177,6 @@ def get_formatted_prompt(
prompt = data["prompt"] prompt = data["prompt"]
return prompt return prompt
def get_tool_call_function_args(message: dict) -> str:
all_args = ""
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
if "function" in tool_call:
all_args += tool_call["function"]["arguments"]
return all_args
def get_response_string(response_obj: ModelResponse) -> str: def get_response_string(response_obj: ModelResponse) -> str:
_choices: List[Union[Choices, StreamingChoices]] = response_obj.choices _choices: List[Union[Choices, StreamingChoices]] = response_obj.choices