From b83f47e941487be93608c444a395be366c9b78e5 Mon Sep 17 00:00:00 2001
From: Vinnie Giarrusso
Date: Tue, 16 Jul 2024 12:19:31 -0700
Subject: [PATCH] refactor a bit

---
 .../tests/test_lakera_ai_prompt_injection.py | 17 +++--------------
 litellm/utils.py                             | 16 +++++-----------
 2 files changed, 8 insertions(+), 25 deletions(-)

diff --git a/litellm/tests/test_lakera_ai_prompt_injection.py b/litellm/tests/test_lakera_ai_prompt_injection.py
index 455f3292b..57d7cffcc 100644
--- a/litellm/tests/test_lakera_ai_prompt_injection.py
+++ b/litellm/tests/test_lakera_ai_prompt_injection.py
@@ -1,21 +1,12 @@
 # What is this?
 ## This tests the Lakera AI integration
-import asyncio
 import os
-import random
 import sys
-import time
-import traceback
 import json
-from datetime import datetime
 
 from dotenv import load_dotenv
 from fastapi import HTTPException
 
-import litellm.llms
-import litellm.llms.custom_httpx
-import litellm.llms.custom_httpx.http_handler
-import litellm.llms.custom_httpx.httpx_handler
 from litellm.types.guardrails import GuardrailItem
 
 load_dotenv()
@@ -29,15 +20,14 @@ import logging
 
 import pytest
 
 import litellm
-from litellm import Router, mock_completion
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
     _ENTERPRISE_lakeraAI_Moderation,
 )
-from litellm.proxy.utils import ProxyLogging, hash_token
-from unittest.mock import patch, MagicMock
+from litellm.proxy.utils import hash_token
+from unittest.mock import patch
 
 verbose_proxy_logger.setLevel(logging.DEBUG)
@@ -59,7 +49,6 @@ async def test_lakera_prompt_injection_detection():
     _api_key = "sk-12345"
     _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
-    local_cache = DualCache()
 
     try:
         await lakera_ai.async_moderation_hook(
@@ -94,7 +83,7 @@ async def test_lakera_safe_prompt():
     _api_key = "sk-12345"
     _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
-    local_cache = DualCache()
+
     await lakera_ai.async_moderation_hook(
         data={
             "messages": [
diff --git a/litellm/utils.py b/litellm/utils.py
index 399446c4b..88dac39bf 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4157,7 +4157,11 @@ def get_formatted_prompt(
                     for c in content:
                         if c["type"] == "text":
                             prompt += c["text"]
-            prompt += get_tool_call_function_args(message)
+            if "tool_calls" in message:
+                for tool_call in message["tool_calls"]:
+                    if "function" in tool_call:
+                        function_arguments = tool_call["function"]["arguments"]
+                        prompt += function_arguments
     elif call_type == "text_completion":
         prompt = data["prompt"]
     elif call_type == "embedding" or call_type == "moderation":
@@ -4173,16 +4177,6 @@ def get_formatted_prompt(
         prompt = data["prompt"]
     return prompt
 
-def get_tool_call_function_args(message: dict) -> str:
-    all_args = ""
-    if "tool_calls" in message:
-        for tool_call in message["tool_calls"]:
-            if "function" in tool_call:
-                all_args += tool_call["function"]["arguments"]
-
-    return all_args
-
-
 def get_response_string(response_obj: ModelResponse) -> str:
     _choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
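
Note for reviewers: below is a minimal, standalone sketch of the behavior the
inlined branch implements after this patch, assuming OpenAI-style chat
messages. The name `extract_prompt` and the sample messages are illustrative
only; the real logic lives inside get_formatted_prompt in litellm/utils.py.

# Minimal sketch, not litellm's actual code: reproduces only the tool_calls
# handling that this patch inlines. `extract_prompt` is a hypothetical
# standalone name used here for illustration.
from typing import List


def extract_prompt(messages: List[dict]) -> str:
    prompt = ""
    for message in messages:
        content = message.get("content")
        if isinstance(content, str):
            prompt += content
        elif isinstance(content, list):
            # Multi-part content: keep only the text parts.
            for c in content:
                if c["type"] == "text":
                    prompt += c["text"]
        if "tool_calls" in message:
            # The inlined branch: append each function call's raw JSON
            # argument string directly to the prompt.
            for tool_call in message["tool_calls"]:
                if "function" in tool_call:
                    prompt += tool_call["function"]["arguments"]
    return prompt


if __name__ == "__main__":
    msgs = [
        {"role": "user", "content": "What is the weather in SF?"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {"function": {"name": "get_weather", "arguments": '{"city": "SF"}'}}
            ],
        },
    ]
    # Prints: What is the weather in SF?{"city": "SF"}
    print(extract_prompt(msgs))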