# litellm-mirror/tests/litellm/integrations/test_custom_prompt_management.py
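"""Tests for litellm's CustomPromptManagement integration.

These tests verify that a custom get_chat_completion_prompt hook can rewrite
the outgoing messages based on prompt_id and prompt_variables, and that
requests without a prompt_id pass through unchanged.
"""
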
import json
import os
import sys
from typing import List, Optional, Tuple
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams

@pytest.fixture(autouse=True)
def setup_anthropic_api_key(monkeypatch):
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-some-key")

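# The fixture above sets a placeholder key, presumably so that litellm's
# Anthropic integration finds ANTHROPIC_API_KEY during client setup; every
# HTTP call in this file is mocked, so no real key or request is ever used.
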
class TestCustomPromptManagement(CustomPromptManagement):
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
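        """Rewrite the prompt for known prompt_ids; pass everything else through.

        litellm invokes this hook before the provider request is built, so the
        returned (model, messages, non_default_params) tuple is what actually
        gets sent upstream.
        """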
        print(
            "TestCustomPromptManagement: running get_chat_completion_prompt for prompt_id: ",
            prompt_id,
        )
        if prompt_id == "test_prompt_id":
            messages = [
                {"role": "user", "content": "This is the prompt for test_prompt_id"},
            ]
            return model, messages, non_default_params
        elif prompt_id == "prompt_with_variables":
            content = "Hello, {name}! You are {age} years old and live in {city}."
            content_with_variables = content.format(**(prompt_variables or {}))
            messages = [
                {"role": "user", "content": content_with_variables},
            ]
            return model, messages, non_default_params
        else:
            return model, messages, non_default_params

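# A minimal sketch (not exercised by these tests) of how such a hook is wired
# up outside of pytest; prompt_id and prompt_variables are ordinary kwargs on
# the completion call:
#
#     litellm.callbacks = [TestCustomPromptManagement()]
#     litellm.completion(
#         model="anthropic/claude-3-5-sonnet",
#         messages=[],
#         prompt_id="prompt_with_variables",
#         prompt_variables={"name": "John", "age": 30, "city": "New York"},
#     )
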
@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id():
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real request is made
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
            prompt_id="test_prompt_id",
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"
        # the callback replaced the outgoing message based on prompt_id
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "This is the prompt for test_prompt_id"
        )

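# Note: Anthropic's messages API represents message content as a list of
# content blocks, hence the ["content"][0]["text"] indexing in these
# assertions.
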
@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id_and_prompt_variables():
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real request is made
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[],
            client=client,
            prompt_id="prompt_with_variables",
            prompt_variables={"name": "John", "age": 30, "city": "New York"},
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"
        # the callback substituted prompt_variables into the prompt template
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "Hello, John! You are 30 years old and live in New York."
        )

@pytest.mark.asyncio
async def test_custom_prompt_management_without_prompt_id():
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real request is made
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"
        # the message is not rewritten by the custom prompt management
        # callback, since no prompt_id was passed
        assert (
            request_body["messages"][0]["content"][0]["text"] == "Hello, how are you?"
        )