fix(proxy_server.py): get master key from environment, if not set in general settings or general settings not set at all (#9617)

* fix(proxy_server.py): get master key from environment, if not set in general settings or general settings not set at all

* test: mark flaky test

* test(test_proxy_server.py): mock prisma client

* ci: add new github workflow for testing just the mock tests

* fix: fix linting error

* ci(conftest.py): add conftest.py to isolate proxy tests

* build(pyproject.toml): add respx to dev dependencies

* build(pyproject.toml): add prisma to dev dependencies

* test: fix mock prompt management tests to use a mock anthropic key

* ci(test-litellm.yml): parallelize mock testing

make it run faster

* build(pyproject.toml): add hypercorn as dev dep

* build(pyproject.toml): separate proxy vs. core dev dependencies

make it easier for non-proxy contributors to run tests locally - e.g. no need to install hypercorn

* ci(test-litellm.yml): pin python version

* test(test_rerank.py): move test - cannot be mocked, requires aws credentials for e2e testing

* ci: add thank you message to ci

* test: add mock env var to test

* test: add autouse to tests

* test: test mock env vars for e2e tests
This commit is contained in:
Krish Dholakia 2025-03-28 12:32:04 -07:00 committed by Nicholas Grabar
parent 1f2bbda11d
commit 205db622bf
13 changed files with 477 additions and 282 deletions

35
.github/workflows/test-litellm.yml vendored Normal file
View file

@ -0,0 +1,35 @@
# CI workflow for the mocked LiteLLM unit tests.
# Runs on every pull request targeting main; intended to finish in minutes.
name: LiteLLM Tests

on:
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    # Mock tests should be fast — fail the job early if something hangs.
    timeout-minutes: 5
    steps:
      - uses: actions/checkout@v4
      - name: Thank You Message
        run: |
          echo "### 🙏 Thank you for contributing to LiteLLM!" >> $GITHUB_STEP_SUMMARY
          echo "Your PR is being tested now. We appreciate your help in making LiteLLM better!" >> $GITHUB_STEP_SUMMARY
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          # Pinned so CI results are reproducible across runs.
          python-version: '3.12'
      - name: Install Poetry
        uses: snok/install-poetry@v1
      - name: Install dependencies
        run: |
          poetry install --with dev,proxy-dev --extras proxy
          poetry run pip install pytest-xdist
      - name: Run tests
        run: |
          # -x: stop at first failure; -n 4: parallelize via pytest-xdist.
          poetry run pytest tests/litellm -x -vv -n 4

View file

@ -14,6 +14,9 @@ help:
install-dev: install-dev:
poetry install --with dev poetry install --with dev
install-proxy-dev:
poetry install --with dev,proxy-dev
lint: install-dev lint: install-dev
poetry run pip install types-requests types-setuptools types-redis types-PyYAML poetry run pip install types-requests types-setuptools types-redis types-PyYAML
cd litellm && poetry run mypy . --ignore-missing-imports cd litellm && poetry run mypy . --ignore-missing-imports

View file

@ -463,6 +463,8 @@ async def proxy_startup_event(app: FastAPI):
if premium_user is False: if premium_user is False:
premium_user = _license_check.is_premium() premium_user = _license_check.is_premium()
## CHECK MASTER KEY IN ENVIRONMENT ##
master_key = get_secret_str("LITELLM_MASTER_KEY")
### LOAD CONFIG ### ### LOAD CONFIG ###
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH") env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")

444
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -99,6 +99,11 @@ mypy = "^1.0"
pytest = "^7.4.3" pytest = "^7.4.3"
pytest-mock = "^3.12.0" pytest-mock = "^3.12.0"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
respx = "^0.20.2"
[tool.poetry.group.proxy-dev.dependencies]
prisma = "0.11.0"
hypercorn = "^0.15.0"
[build-system] [build-system]
requires = ["poetry-core", "wheel"] requires = ["poetry-core", "wheel"]

63
tests/litellm/conftest.py Normal file
View file

@ -0,0 +1,63 @@
# conftest.py
import importlib
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
    """
    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.

    Runs automatically for every test in this directory (autouse). It also
    gives each test a fresh asyncio event loop and closes it on teardown.
    """
    # NOTE(review): curr_dir is assigned but never used — confirm it can be removed.
    curr_dir = os.getcwd()  # Get the current working directory
    sys.path.insert(
        0, os.path.abspath("../..")
    )  # Adds the project directory to the system path

    import litellm
    from litellm import Router

    # Reload litellm so module-level state (e.g. registered callbacks) from a
    # previous test does not leak into this one.
    importlib.reload(litellm)

    try:
        # If the proxy server module was already imported, reload it too so it
        # re-binds to the freshly reloaded litellm module.
        if hasattr(litellm, "proxy") and hasattr(litellm.proxy, "proxy_server"):
            import litellm.proxy.proxy_server

            importlib.reload(litellm.proxy.proxy_server)
    except Exception as e:
        print(f"Error reloading litellm.proxy.proxy_server: {e}")

    import asyncio

    # Give every test its own event loop; shared loops cause cross-test leakage.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    print(litellm)
    # from litellm import Router, completion, aembedding, acompletion, embedding
    yield

    # Teardown code (executes after the yield point)
    loop.close()  # Close the loop created earlier
    asyncio.set_event_loop(None)  # Remove the reference to the loop
def pytest_collection_modifyitems(config, items):
    """Reorder collected tests: custom-logger tests run first.

    Items whose parent (module/class) name contains ``custom_logger`` are
    moved to the front; each group is independently sorted by test name.
    The ``items`` list is mutated in place, as pytest requires.
    """
    logger_group = []
    remaining = []
    for test_item in items:
        # Group membership is decided by the parent node's name.
        if "custom_logger" in test_item.parent.name:
            logger_group.append(test_item)
        else:
            remaining.append(test_item)

    # Stable sort within each group, then splice back into the original list.
    items[:] = sorted(logger_group, key=lambda t: t.name) + sorted(
        remaining, key=lambda t: t.name
    )

View file

@ -19,6 +19,11 @@ from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams from litellm.types.utils import StandardCallbackDynamicParams
@pytest.fixture(autouse=True)
def setup_anthropic_api_key(monkeypatch):
    # autouse: applied to every test in this module. The placeholder key lets
    # Anthropic client construction succeed without a real credential;
    # monkeypatch restores the original environment after each test.
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-some-key")
class TestCustomPromptManagement(CustomPromptManagement): class TestCustomPromptManagement(CustomPromptManagement):
def get_chat_completion_prompt( def get_chat_completion_prompt(
self, self,
@ -50,7 +55,7 @@ class TestCustomPromptManagement(CustomPromptManagement):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id(): async def test_custom_prompt_management_with_prompt_id(monkeypatch):
custom_prompt_management = TestCustomPromptManagement() custom_prompt_management = TestCustomPromptManagement()
litellm.callbacks = [custom_prompt_management] litellm.callbacks = [custom_prompt_management]

View file

@ -26,6 +26,11 @@ def client():
return TestClient(app) return TestClient(app)
@pytest.fixture(autouse=True)
def add_anthropic_api_key_to_env(monkeypatch):
    # autouse: injects a placeholder Anthropic key for every test in this
    # module; monkeypatch undoes the env change during teardown.
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-1234567890")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ui_view_spend_logs_with_user_id(client, monkeypatch): async def test_ui_view_spend_logs_with_user_id(client, monkeypatch):
# Mock data for the test # Mock data for the test
@ -500,7 +505,7 @@ class TestSpendLogsPayload:
return mock_response return mock_response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_spend_logs_payload_success_log_with_api_base(self): async def test_spend_logs_payload_success_log_with_api_base(self, monkeypatch):
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
litellm.callbacks = [_ProxyDBLogger(message_logging=False)] litellm.callbacks = [_ProxyDBLogger(message_logging=False)]

View file

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import click import click
import httpx import httpx
import pytest import pytest
import yaml
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -74,3 +75,92 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
call[0] for call in mock_proxy_config.get_credentials.mock_calls call[0] for call in mock_proxy_config.get_credentials.mock_calls
] ]
assert len(mock_scheduler_calls) > 0 assert len(mock_scheduler_calls) > 0
# Mock Prisma — a lightweight stand-in so startup tests never touch a real DB.
class MockPrisma:
    """Drop-in fake for PrismaClient.

    Records the constructor arguments it was given and turns
    ``connect``/``disconnect`` into awaitable no-ops.
    """

    def __init__(self, database_url=None, proxy_logging_obj=None, http_client=None):
        # Keep the args around so a test can assert how the client was built.
        self.database_url = database_url
        self.proxy_logging_obj = proxy_logging_obj
        self.http_client = http_client

    async def connect(self):
        """Pretend to open a database connection."""
        return None

    async def disconnect(self):
        """Pretend to close the database connection."""
        return None


# Shared instance returned by the patched _setup_prisma_client in tests below.
mock_prisma = MockPrisma()
@patch(
    "litellm.proxy.proxy_server.ProxyStartupEvent._setup_prisma_client",
    return_value=mock_prisma,
)
@pytest.mark.asyncio
async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path):
    """
    Test that master_key is correctly loaded from either config.yaml or environment variables

    Three cases: (1) key set in general_settings of config.yaml,
    (2) key only in the LITELLM_MASTER_KEY env var,
    (3) key referenced via an "os.environ/<VAR>" prefix in the config.
    NOTE: the "aaa" prefix presumably forces this test to run first in the
    module (tests are sorted by name in conftest) — confirm before renaming.
    """
    import yaml
    from fastapi import FastAPI

    # Import happens here - this is when the module probably reads the config path
    from litellm.proxy.proxy_server import proxy_startup_event

    # Mock the Prisma import
    monkeypatch.setattr("litellm.proxy.proxy_server.PrismaClient", MockPrisma)

    # Create test app
    app = FastAPI()

    # Test Case 1: Master key from config.yaml
    test_master_key = "sk-12345"
    test_config = {"general_settings": {"master_key": test_master_key}}

    # Create a temporary config file
    config_path = tmp_path / "config.yaml"
    with open(config_path, "w") as f:
        yaml.dump(test_config, f)

    print(f"SET ENV VARIABLE - CONFIG_FILE_PATH, str(config_path): {str(config_path)}")
    # Second setting of CONFIG_FILE_PATH to a different value
    monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path))
    print(f"config_path: {config_path}")
    print(f"os.getenv('CONFIG_FILE_PATH'): {os.getenv('CONFIG_FILE_PATH')}")
    # master_key is a module-level global — re-import inside the startup
    # context to read the value the lifespan handler just set.
    async with proxy_startup_event(app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_master_key

    # Test Case 2: Master key from environment variable
    test_env_master_key = "sk-67890"

    # Create empty config
    empty_config = {"general_settings": {}}
    with open(config_path, "w") as f:
        yaml.dump(empty_config, f)

    monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key)
    print("test_env_master_key: {}".format(test_env_master_key))
    async with proxy_startup_event(app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_env_master_key

    # Test Case 3: Master key with os.environ prefix
    test_resolved_key = "sk-resolved-key"
    test_config_with_prefix = {
        "general_settings": {"master_key": "os.environ/CUSTOM_MASTER_KEY"}
    }

    # Create config with os.environ prefix
    with open(config_path, "w") as f:
        yaml.dump(test_config_with_prefix, f)

    monkeypatch.setenv("CUSTOM_MASTER_KEY", test_resolved_key)
    async with proxy_startup_event(app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_resolved_key

View file

@ -1,51 +0,0 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
from litellm import rerank
from litellm.llms.custom_httpx.http_handler import HTTPHandler
def test_rerank_infer_region_from_model_arn(monkeypatch):
    """Bedrock rerank should derive the AWS region from the model ARN.

    The ARN names us-west-2 while AWS_REGION_NAME is set to us-east-1; the
    request URL must contain the ARN's region, not the env var's.
    """
    mock_response = MagicMock()

    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
    args = {
        "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        "query": "hello",
        "documents": ["hello", "world"],
    }

    def return_val():
        # Canned Bedrock rerank payload returned by the mocked POST.
        return {
            "results": [
                {"index": 0, "relevanceScore": 0.6716859340667725},
                {"index": 1, "relevanceScore": 0.0004994205664843321},
            ]
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    client = HTTPHandler()
    # Patch the HTTP layer so no real network call is made.
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        rerank(
            model=args["model"],
            query=args["query"],
            documents=args["documents"],
            client=client,
        )

        mock_post.assert_called_once()
        print(f"mock_post.call_args: {mock_post.call_args.kwargs}")

        # Region from the ARN wins over AWS_REGION_NAME.
        assert "us-west-2" in mock_post.call_args.kwargs["url"]
        assert "us-east-1" not in mock_post.call_args.kwargs["url"]

View file

@ -14,6 +14,15 @@ from unittest.mock import MagicMock, patch
import litellm import litellm
@pytest.fixture(autouse=True)
def add_api_keys_to_env(monkeypatch):
    # autouse: seeds placeholder provider credentials for every test in this
    # module so mocked calls do not depend on the developer's real environment;
    # monkeypatch restores the env after each test.
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-1234567890")
    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-api03-1234567890")
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "my-fake-aws-access-key-id")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "my-fake-aws-secret-access-key")
    monkeypatch.setenv("AWS_REGION", "us-east-1")
@pytest.fixture @pytest.fixture
def openai_api_response(): def openai_api_response():
mock_response_data = { mock_response_data = {
@ -130,7 +139,8 @@ def test_completion_missing_role(openai_api_response):
) )
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_url_with_format_param(model, sync_mode): async def test_url_with_format_param(model, sync_mode, monkeypatch):
from litellm import acompletion, completion from litellm import acompletion, completion
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

View file

@ -22,6 +22,7 @@ litellm.num_retries = 3
@pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("stream", [True, False])
@pytest.mark.flaky(retries=3, delay=1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_completion_cohere_citations(stream): async def test_chat_completion_cohere_citations(stream):
try: try:

View file

@ -481,3 +481,42 @@ def test_rerank_cohere_api():
assert response.results[0]["document"]["text"] is not None assert response.results[0]["document"]["text"] is not None
assert response.results[0]["document"]["text"] == "hello" assert response.results[0]["document"]["text"] == "hello"
assert response.results[1]["document"]["text"] == "world" assert response.results[1]["document"]["text"] == "world"
def test_rerank_infer_region_from_model_arn(monkeypatch):
    """Bedrock rerank must use the region embedded in the model ARN
    (us-west-2), overriding the AWS_REGION_NAME env var (us-east-1)."""
    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")

    request_args = {
        "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        "query": "hello",
        "documents": ["hello", "world"],
    }

    # Canned HTTP response so no real Bedrock call is made.
    fake_response = MagicMock()
    fake_response.json = lambda: {
        "results": [
            {"index": 0, "relevanceScore": 0.6716859340667725},
            {"index": 1, "relevanceScore": 0.0004994205664843321},
        ]
    }
    fake_response.headers = {"key": "value"}
    fake_response.status_code = 200

    http_client = HTTPHandler()
    with patch.object(http_client, "post", return_value=fake_response) as mock_post:
        litellm.rerank(
            model=request_args["model"],
            query=request_args["query"],
            documents=request_args["documents"],
            client=http_client,
        )

    mock_post.assert_called_once()
    print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
    requested_url = mock_post.call_args.kwargs["url"]
    # Region from the ARN wins over AWS_REGION_NAME.
    assert "us-west-2" in requested_url
    assert "us-east-1" not in requested_url