litellm-mirror/tests/litellm/integrations/test_custom_prompt_management.py

import json
import os
import sys
from typing import List, Optional, Tuple
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


@pytest.fixture(autouse=True)
def setup_anthropic_api_key(monkeypatch):
    """Give every test a mock Anthropic key so no real credentials are required."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-some-key")
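

# The tests below assign to the global litellm.callbacks list. A cleanup fixture
# along these lines (a sketch; it is not part of the original file) would keep
# that global state from leaking between tests:
@pytest.fixture(autouse=True)
def reset_litellm_callbacks():
    yield
    litellm.callbacks = []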


class TestCustomPromptManagement(CustomPromptManagement):
    """Minimal prompt-management callback used as the test double in this module."""

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        print(
            "TestCustomPromptManagement: running get_chat_completion_prompt for prompt_id: ",
            prompt_id,
        )
        if prompt_id == "test_prompt_id":
            # Replace the caller's messages with a fixed prompt.
            messages = [
                {"role": "user", "content": "This is the prompt for test_prompt_id"},
            ]
            return model, messages, non_default_params
        elif prompt_id == "prompt_with_variables":
            # Interpolate the caller-supplied variables into a prompt template.
            content = "Hello, {name}! You are {age} years old and live in {city}."
            content_with_variables = content.format(**(prompt_variables or {}))
            messages = [
                {"role": "user", "content": content_with_variables},
            ]
            return model, messages, non_default_params
        else:
            # Unknown (or missing) prompt_id: pass everything through unchanged.
            return model, messages, non_default_params
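

# A direct unit-test sketch (not in the original file): get_chat_completion_prompt
# can also be exercised without going through litellm.acompletion or any HTTP
# mocking. The test name and argument values here are illustrative assumptions.
def test_get_chat_completion_prompt_interpolates_variables_directly():
    handler = TestCustomPromptManagement()
    model, messages, non_default_params = handler.get_chat_completion_prompt(
        model="anthropic/claude-3-5-sonnet",
        messages=[],
        non_default_params={},
        prompt_id="prompt_with_variables",
        prompt_variables={"name": "Ada", "age": 36, "city": "London"},
        dynamic_callback_params={},
    )
    assert messages[0]["content"] == "Hello, Ada! You are 36 years old and live in London."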


@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id():
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real Anthropic request is sent
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
            prompt_id="test_prompt_id",
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"

        # the callback replaces the user's message with the prompt registered for test_prompt_id
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "This is the prompt for test_prompt_id"
        )


@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id_and_prompt_variables():
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real Anthropic request is sent
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[],
            client=client,
            prompt_id="prompt_with_variables",
            prompt_variables={"name": "John", "age": 30, "city": "New York"},
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"

        # the callback renders the prompt template with the supplied variables
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "Hello, John! You are 30 years old and live in New York."
        )


@pytest.mark.asyncio
async def test_custom_prompt_management_without_prompt_id():
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real Anthropic request is sent
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"

        # without a prompt_id, the callback leaves the user's message untouched
        assert (
            request_body["messages"][0]["content"][0]["text"] == "Hello, how are you?"
        )