mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
* test_anthropic_cache_control_hook_system_message * test_anthropic_cache_control_hook.py * should_run_prompt_management_hooks * fix should_run_prompt_management_hooks * test_anthropic_cache_control_hook_specific_index * fix test * fix linting errors * ChatCompletionCachedContent * initial commit for cache control * fixes ui design * fix inserting cache_control_injection_points * fix entering cache control points * fixes for using cache control on ui + backend * update cache control settings on edit model page * fix init custom logger compatible class * fix linting errors * fix linting errors * fix get_chat_completion_prompt
137 lines
5.1 KiB
Python
137 lines
5.1 KiB
Python
import datetime
|
|
import json
|
|
import os
|
|
import sys
|
|
import unittest
|
|
from typing import List, Optional, Tuple
|
|
from unittest.mock import ANY, MagicMock, Mock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
# Put the directory two levels above the current working directory at the
# front of sys.path so the local litellm checkout is found before any
# installed copy when `import litellm` runs below. Note that "../.." is
# resolved against the CWD, not this file's location.
sys.path.insert(0, os.path.abspath("../.."))
|
|
import litellm
|
|
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
|
from litellm.types.llms.openai import AllMessageValues
|
|
from litellm.types.utils import StandardCallbackDynamicParams
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def setup_anthropic_api_key(monkeypatch):
    """Give every test in this module a dummy Anthropic API key.

    The variable is set through ``monkeypatch`` so pytest reverts it when
    each test finishes.
    """
    env_var, fake_key = "ANTHROPIC_API_KEY", "sk-ant-some-key"
    monkeypatch.setenv(env_var, fake_key)
|
|
|
|
|
|
class TestCustomPromptManagement(CustomPromptManagement):
    """Stub prompt-management callback used by the tests in this module.

    It replaces the outgoing messages for two known prompt ids and leaves the
    request untouched for any other prompt id.
    """

    # pytest collects classes whose names start with "Test"; this is a helper
    # callback, not a test container, so opt out of collection explicitly.
    __test__ = False

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Return ``(model, messages, non_default_params)`` for ``prompt_id``.

        * ``"test_prompt_id"``: ``messages`` is replaced by one fixed user
          message.
        * ``"prompt_with_variables"``: ``messages`` is replaced by one user
          message rendered from a template with ``prompt_variables``.
        * any other value, including ``None``: ``messages`` is returned
          unchanged.

        ``model`` and ``non_default_params`` are always passed through as-is.

        Raises:
            KeyError: for ``"prompt_with_variables"`` when ``prompt_variables``
                lacks one of ``name``, ``age`` or ``city`` (``str.format``).
        """
        print(
            "TestCustomPromptManagement: running get_chat_completion_prompt for prompt_id: ",
            prompt_id,
        )

        if prompt_id == "test_prompt_id":
            messages = [
                {"role": "user", "content": "This is the prompt for test_prompt_id"},
            ]
        elif prompt_id == "prompt_with_variables":
            template = "Hello, {name}! You are {age} years old and live in {city}."
            # `or {}` tolerates prompt_variables=None; format() still raises
            # KeyError if a placeholder has no value.
            rendered = template.format(**(prompt_variables or {}))
            messages = [
                {"role": "user", "content": rendered},
            ]

        return model, messages, non_default_params
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id(monkeypatch):
    """A prompt_id known to the callback replaces the caller's messages."""
    custom_prompt_management = TestCustomPromptManagement()
    # Register the callback only for this test. monkeypatch restores the
    # previous `litellm.callbacks` on teardown; a bare assignment would leave
    # the callback installed for later tests in the same process.
    monkeypatch.setattr(litellm, "callbacks", [custom_prompt_management])

    # Patch the client's `post` so the outgoing request is captured on the
    # mock instead of being sent; its JSON body is inspected below.
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
            prompt_id="test_prompt_id",
        )

    mock_post.assert_called_once()
    print(mock_post.call_args.kwargs)
    request_body = mock_post.call_args.kwargs["json"]
    print("request_body: ", json.dumps(request_body, indent=4))

    assert request_body["model"] == "claude-3-5-sonnet"
    # the message gets applied to the prompt from the custom prompt management callback
    assert (
        request_body["messages"][0]["content"][0]["text"]
        == "This is the prompt for test_prompt_id"
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id_and_prompt_variables(
    monkeypatch,
):
    """prompt_variables are rendered into the callback's prompt template."""
    custom_prompt_management = TestCustomPromptManagement()
    # Register the callback only for this test. monkeypatch restores the
    # previous `litellm.callbacks` on teardown; a bare assignment would leave
    # the callback installed for later tests in the same process.
    monkeypatch.setattr(litellm, "callbacks", [custom_prompt_management])

    # Patch the client's `post` so the outgoing request is captured on the
    # mock instead of being sent; its JSON body is inspected below.
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            # Empty on purpose: the callback supplies the only message.
            messages=[],
            client=client,
            prompt_id="prompt_with_variables",
            prompt_variables={"name": "John", "age": 30, "city": "New York"},
        )

    mock_post.assert_called_once()
    print(mock_post.call_args.kwargs)
    request_body = mock_post.call_args.kwargs["json"]
    print("request_body: ", json.dumps(request_body, indent=4))

    assert request_body["model"] == "claude-3-5-sonnet"
    # the message gets applied to the prompt from the custom prompt management callback
    assert (
        request_body["messages"][0]["content"][0]["text"]
        == "Hello, John! You are 30 years old and live in New York."
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_custom_prompt_management_without_prompt_id(monkeypatch):
    """Without a prompt_id the caller's messages are sent unchanged."""
    custom_prompt_management = TestCustomPromptManagement()
    # Register the callback only for this test. monkeypatch restores the
    # previous `litellm.callbacks` on teardown; a bare assignment would leave
    # the callback installed for later tests in the same process.
    monkeypatch.setattr(litellm, "callbacks", [custom_prompt_management])

    # Patch the client's `post` so the outgoing request is captured on the
    # mock instead of being sent; its JSON body is inspected below.
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
        )

    mock_post.assert_called_once()
    print(mock_post.call_args.kwargs)
    request_body = mock_post.call_args.kwargs["json"]
    print("request_body: ", json.dumps(request_body, indent=4))

    assert request_body["model"] == "claude-3-5-sonnet"
    # the message does not get applied to the prompt from the custom prompt management callback since we did not pass a prompt_id
    assert (
        request_body["messages"][0]["content"][0]["text"] == "Hello, how are you?"
    )
|