mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(proxy_server.py): get master key from environment, if not set in … (#9617)
* fix(proxy_server.py): get master key from environment, if not set in general settings or general settings not set at all
* test: mark flaky test
* test(test_proxy_server.py): mock prisma client
* ci: add new github workflow for testing just the mock tests
* fix: fix linting error
* ci(conftest.py): add conftest.py to isolate proxy tests
* build(pyproject.toml): add respx to dev dependencies
* build(pyproject.toml): add prisma to dev dependencies
* test: fix mock prompt management tests to use a mock anthropic key
* ci(test-litellm.yml): parallelize mock testing
  make it run faster
* build(pyproject.toml): add hypercorn as dev dep
* build(pyproject.toml): separate proxy vs. core dev dependencies
  make it easier for non-proxy contributors to run tests locally - e.g. no need to install hypercorn
* ci(test-litellm.yml): pin python version
* test(test_rerank.py): move test - cannot be mocked, requires aws credentials for e2e testing
* ci: add thank you message to ci
* test: add mock env var to test
* test: add autouse to tests
* test: test mock env vars for e2e tests
This commit is contained in:
parent
1f2bbda11d
commit
205db622bf
13 changed files with 477 additions and 282 deletions
35
.github/workflows/test-litellm.yml
vendored
Normal file
35
.github/workflows/test-litellm.yml
vendored
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
# Runs the test suite under tests/litellm for every pull request targeting main.
name: LiteLLM Tests

on:
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    # Fail the job if it has not finished within five minutes.
    timeout-minutes: 5

    steps:
      - uses: actions/checkout@v4

      # Appends a short note to the job summary; the path variable is quoted so
      # the redirect target is never word-split or glob-expanded by the shell.
      - name: Thank You Message
        run: |
          echo "### 🙏 Thank you for contributing to LiteLLM!" >> "$GITHUB_STEP_SUMMARY"
          echo "Your PR is being tested now. We appreciate your help in making LiteLLM better!" >> "$GITHUB_STEP_SUMMARY"

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'

      - name: Install Poetry
        uses: snok/install-poetry@v1

      # Installs the core package with the dev and proxy-dev dependency groups
      # and the proxy extra, then adds pytest-xdist for the parallel run below.
      - name: Install dependencies
        run: |
          poetry install --with dev,proxy-dev --extras proxy
          poetry run pip install pytest-xdist

      # -x stops at the first failure, -vv is extra verbose, -n 4 uses four xdist workers.
      - name: Run tests
        run: |
          poetry run pytest tests/litellm -x -vv -n 4
|
3
Makefile
3
Makefile
|
@ -14,6 +14,9 @@ help:
|
||||||
install-dev:
|
install-dev:
|
||||||
poetry install --with dev
|
poetry install --with dev
|
||||||
|
|
||||||
|
install-proxy-dev:
|
||||||
|
poetry install --with dev,proxy-dev
|
||||||
|
|
||||||
lint: install-dev
|
lint: install-dev
|
||||||
poetry run pip install types-requests types-setuptools types-redis types-PyYAML
|
poetry run pip install types-requests types-setuptools types-redis types-PyYAML
|
||||||
cd litellm && poetry run mypy . --ignore-missing-imports
|
cd litellm && poetry run mypy . --ignore-missing-imports
|
||||||
|
|
|
@ -463,6 +463,8 @@ async def proxy_startup_event(app: FastAPI):
|
||||||
if premium_user is False:
|
if premium_user is False:
|
||||||
premium_user = _license_check.is_premium()
|
premium_user = _license_check.is_premium()
|
||||||
|
|
||||||
|
## CHECK MASTER KEY IN ENVIRONMENT ##
|
||||||
|
master_key = get_secret_str("LITELLM_MASTER_KEY")
|
||||||
### LOAD CONFIG ###
|
### LOAD CONFIG ###
|
||||||
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
|
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
|
||||||
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
|
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
|
||||||
|
|
444
poetry.lock
generated
444
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -99,6 +99,11 @@ mypy = "^1.0"
|
||||||
pytest = "^7.4.3"
|
pytest = "^7.4.3"
|
||||||
pytest-mock = "^3.12.0"
|
pytest-mock = "^3.12.0"
|
||||||
pytest-asyncio = "^0.21.1"
|
pytest-asyncio = "^0.21.1"
|
||||||
|
respx = "^0.20.2"
|
||||||
|
|
||||||
|
[tool.poetry.group.proxy-dev.dependencies]
|
||||||
|
prisma = "0.11.0"
|
||||||
|
hypercorn = "^0.15.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core", "wheel"]
|
requires = ["poetry-core", "wheel"]
|
||||||
|
|
63
tests/litellm/conftest.py
Normal file
63
tests/litellm/conftest.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
# conftest.py
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
    """
    Reload litellm and install a fresh asyncio event loop around every test function.

    Setup (before the test runs):
      - makes sure the project directory (``../..`` resolved against the current
        working directory) is on ``sys.path``; it is added only when missing, so
        the path does not gain a duplicate entry for every test,
      - reloads ``litellm`` so callbacks registered by an earlier test are not
        chained onto later ones,
      - reloads ``litellm.proxy.proxy_server`` when it is already reachable as an
        attribute of the package; a failure here is reported and otherwise ignored,
      - creates a new event loop and makes it the current one.

    Teardown (after the test runs):
      - closes that loop and clears the current-loop reference.
    """
    import asyncio

    project_dir = os.path.abspath("../..")
    if project_dir not in sys.path:
        sys.path.insert(0, project_dir)  # Adds the project directory to the system path

    import litellm

    importlib.reload(litellm)

    try:
        if hasattr(litellm, "proxy") and hasattr(litellm.proxy, "proxy_server"):
            import litellm.proxy.proxy_server

            importlib.reload(litellm.proxy.proxy_server)
    except Exception as e:
        # Best effort: a proxy reload failure must not prevent the test from running.
        print(f"Error reloading litellm.proxy.proxy_server: {e}")

    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    print(litellm)
    yield

    # Teardown code (executes after the yield point)
    loop.close()  # Close the loop created earlier
    asyncio.set_event_loop(None)  # Remove the reference to the loop
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config, items):
    """
    Reorder the collected test items in place.

    Items whose parent (the containing module/class node) has "custom_logger" in
    its name are moved to the front; within each of the two groups the items are
    ordered by their own name.
    """

    def _ordering(test_item):
        # False sorts before True, so the "custom_logger" group leads; the item
        # name is the tie-breaker inside each group. The sort is stable, so
        # equally named items keep their collected order.
        return ("custom_logger" not in test_item.parent.name, test_item.name)

    # Slice assignment keeps the same list object that pytest handed in.
    items[:] = sorted(items, key=_ordering)
|
|
@ -19,6 +19,11 @@ from litellm.types.llms.openai import AllMessageValues
|
||||||
from litellm.types.utils import StandardCallbackDynamicParams
|
from litellm.types.utils import StandardCallbackDynamicParams
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def setup_anthropic_api_key(monkeypatch):
    """Set ANTHROPIC_API_KEY to a placeholder value for every test in this module."""
    placeholder_key = "sk-ant-some-key"
    monkeypatch.setenv("ANTHROPIC_API_KEY", placeholder_key)
|
||||||
|
|
||||||
|
|
||||||
class TestCustomPromptManagement(CustomPromptManagement):
|
class TestCustomPromptManagement(CustomPromptManagement):
|
||||||
def get_chat_completion_prompt(
|
def get_chat_completion_prompt(
|
||||||
self,
|
self,
|
||||||
|
@ -50,7 +55,7 @@ class TestCustomPromptManagement(CustomPromptManagement):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_custom_prompt_management_with_prompt_id():
|
async def test_custom_prompt_management_with_prompt_id(monkeypatch):
|
||||||
custom_prompt_management = TestCustomPromptManagement()
|
custom_prompt_management = TestCustomPromptManagement()
|
||||||
litellm.callbacks = [custom_prompt_management]
|
litellm.callbacks = [custom_prompt_management]
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,11 @@ def client():
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def add_anthropic_api_key_to_env(monkeypatch):
    """Set ANTHROPIC_API_KEY to a placeholder value for every test in this module."""
    placeholder_key = "sk-ant-api03-1234567890"
    monkeypatch.setenv("ANTHROPIC_API_KEY", placeholder_key)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ui_view_spend_logs_with_user_id(client, monkeypatch):
|
async def test_ui_view_spend_logs_with_user_id(client, monkeypatch):
|
||||||
# Mock data for the test
|
# Mock data for the test
|
||||||
|
@ -500,7 +505,7 @@ class TestSpendLogsPayload:
|
||||||
return mock_response
|
return mock_response
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_spend_logs_payload_success_log_with_api_base(self):
|
async def test_spend_logs_payload_success_log_with_api_base(self, monkeypatch):
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
|
||||||
litellm.callbacks = [_ProxyDBLogger(message_logging=False)]
|
litellm.callbacks = [_ProxyDBLogger(message_logging=False)]
|
||||||
|
|
|
@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||||
import click
|
import click
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
@ -74,3 +75,92 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
|
||||||
call[0] for call in mock_proxy_config.get_credentials.mock_calls
|
call[0] for call in mock_proxy_config.get_credentials.mock_calls
|
||||||
]
|
]
|
||||||
assert len(mock_scheduler_calls) > 0
|
assert len(mock_scheduler_calls) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# Mock Prisma
|
||||||
|
class MockPrisma:
    """
    Minimal stand-in for the Prisma client used by the proxy startup test.

    It only records its constructor arguments; connect() and disconnect() are
    awaitable no-ops.
    """

    def __init__(self, database_url=None, proxy_logging_obj=None, http_client=None):
        self.database_url, self.proxy_logging_obj, self.http_client = (
            database_url,
            proxy_logging_obj,
            http_client,
        )

    async def connect(self):
        """Do nothing; there is no database to connect to."""
        return None

    async def disconnect(self):
        """Do nothing; there is no connection to close."""
        return None


# Module-level instance, created with all-default arguments.
mock_prisma = MockPrisma()
|
||||||
|
|
||||||
|
|
||||||
|
# Replace the proxy's Prisma setup hook so startup never touches a real database;
# the patched hook returns the module-level MockPrisma instance.
@patch(
    "litellm.proxy.proxy_server.ProxyStartupEvent._setup_prisma_client",
    return_value=mock_prisma,
)
@pytest.mark.asyncio
async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path):
    """
    Test that master_key is correctly loaded from either config.yaml or environment variables

    Three cases are run back to back against the same temp config file, which is
    rewritten before each one:
      1. master_key given literally in general_settings,
      2. empty general_settings with LITELLM_MASTER_KEY set in the environment,
      3. master_key given as "os.environ/CUSTOM_MASTER_KEY", resolved from the environment.

    NOTE(review): the ``mock_prisma`` parameter shadows the module-level instance;
    presumably it receives the mock object created by the ``@patch`` decorator - confirm.
    Environment changes made through ``monkeypatch`` stay in effect until the test
    ends, so LITELLM_MASTER_KEY from case 2 is still set while case 3 runs.
    """
    import yaml
    from fastapi import FastAPI

    # Import happens here - this is when the module probably reads the config path
    from litellm.proxy.proxy_server import proxy_startup_event

    # Mock the Prisma import
    monkeypatch.setattr("litellm.proxy.proxy_server.PrismaClient", MockPrisma)

    # Create test app
    app = FastAPI()

    # Test Case 1: Master key from config.yaml
    test_master_key = "sk-12345"
    test_config = {"general_settings": {"master_key": test_master_key}}

    # Create a temporary config file
    config_path = tmp_path / "config.yaml"
    with open(config_path, "w") as f:
        yaml.dump(test_config, f)

    print(f"SET ENV VARIABLE - CONFIG_FILE_PATH, str(config_path): {str(config_path)}")
    # Point CONFIG_FILE_PATH at the temporary config file for the startup runs below
    monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path))
    print(f"config_path: {config_path}")
    print(f"os.getenv('CONFIG_FILE_PATH'): {os.getenv('CONFIG_FILE_PATH')}")
    async with proxy_startup_event(app):
        # Imported inside the context so the name is read after startup has run
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_master_key

    # Test Case 2: Master key from environment variable
    test_env_master_key = "sk-67890"

    # Create empty config
    empty_config = {"general_settings": {}}
    with open(config_path, "w") as f:
        yaml.dump(empty_config, f)

    monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key)
    print("test_env_master_key: {}".format(test_env_master_key))
    async with proxy_startup_event(app):
        # Re-import: picks up the value set by this second startup run
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_env_master_key

    # Test Case 3: Master key with os.environ prefix
    test_resolved_key = "sk-resolved-key"
    test_config_with_prefix = {
        "general_settings": {"master_key": "os.environ/CUSTOM_MASTER_KEY"}
    }

    # Create config with os.environ prefix
    with open(config_path, "w") as f:
        yaml.dump(test_config_with_prefix, f)

    monkeypatch.setenv("CUSTOM_MASTER_KEY", test_resolved_key)
    async with proxy_startup_event(app):
        # Re-import: picks up the value set by this third startup run
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_resolved_key
|
||||||
|
|
|
@ -1,51 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from litellm import rerank
|
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
|
||||||
|
|
||||||
|
|
||||||
def test_rerank_infer_region_from_model_arn(monkeypatch):
    """
    Check that rerank() takes the AWS region from the Bedrock model ARN.

    AWS_REGION_NAME is set to us-east-1 while the model ARN names us-west-2; the
    URL passed to the patched HTTP ``post`` must contain us-west-2 and must not
    contain us-east-1.
    """
    mock_response = MagicMock()

    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
    args = {
        "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        "query": "hello",
        "documents": ["hello", "world"],
    }

    # Canned JSON body returned by the fake response
    def return_val():
        return {
            "results": [
                {"index": 0, "relevanceScore": 0.6716859340667725},
                {"index": 1, "relevanceScore": 0.0004994205664843321},
            ]
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    client = HTTPHandler()

    # Patch post() on this handler instance so the call returns the fake response
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        rerank(
            model=args["model"],
            query=args["query"],
            documents=args["documents"],
            client=client,
        )
        mock_post.assert_called_once()
        print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
        # The region in the ARN must be used, not the one from the environment
        assert "us-west-2" in mock_post.call_args.kwargs["url"]
        assert "us-east-1" not in mock_post.call_args.kwargs["url"]
|
|
|
@ -14,6 +14,15 @@ from unittest.mock import MagicMock, patch
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def add_api_keys_to_env(monkeypatch):
    """Set placeholder provider credentials in the environment for every test in this module."""
    placeholder_env = (
        ("ANTHROPIC_API_KEY", "sk-ant-api03-1234567890"),
        ("OPENAI_API_KEY", "sk-openai-api03-1234567890"),
        ("AWS_ACCESS_KEY_ID", "my-fake-aws-access-key-id"),
        ("AWS_SECRET_ACCESS_KEY", "my-fake-aws-secret-access-key"),
        ("AWS_REGION", "us-east-1"),
    )
    for env_name, env_value in placeholder_env:
        monkeypatch.setenv(env_name, env_value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def openai_api_response():
|
def openai_api_response():
|
||||||
mock_response_data = {
|
mock_response_data = {
|
||||||
|
@ -130,7 +139,8 @@ def test_completion_missing_role(openai_api_response):
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_url_with_format_param(model, sync_mode):
|
async def test_url_with_format_param(model, sync_mode, monkeypatch):
|
||||||
|
|
||||||
from litellm import acompletion, completion
|
from litellm import acompletion, completion
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ litellm.num_retries = 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("stream", [True, False])
|
@pytest.mark.parametrize("stream", [True, False])
|
||||||
|
@pytest.mark.flaky(retries=3, delay=1)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_completion_cohere_citations(stream):
|
async def test_chat_completion_cohere_citations(stream):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -481,3 +481,42 @@ def test_rerank_cohere_api():
|
||||||
assert response.results[0]["document"]["text"] is not None
|
assert response.results[0]["document"]["text"] is not None
|
||||||
assert response.results[0]["document"]["text"] == "hello"
|
assert response.results[0]["document"]["text"] == "hello"
|
||||||
assert response.results[1]["document"]["text"] == "world"
|
assert response.results[1]["document"]["text"] == "world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rerank_infer_region_from_model_arn(monkeypatch):
    """
    Check that litellm.rerank() takes the AWS region from the Bedrock model ARN.

    AWS_REGION_NAME is set to us-east-1 while the model ARN names us-west-2; the
    URL passed to the patched HTTP ``post`` must contain us-west-2 and must not
    contain us-east-1.
    """
    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")

    # Fake HTTP response with a canned rerank body
    fake_response = MagicMock()
    fake_response.json = lambda: {
        "results": [
            {"index": 0, "relevanceScore": 0.6716859340667725},
            {"index": 1, "relevanceScore": 0.0004994205664843321},
        ]
    }
    fake_response.headers = {"key": "value"}
    fake_response.status_code = 200

    rerank_kwargs = dict(
        model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        query="hello",
        documents=["hello", "world"],
    )
    http_client = HTTPHandler()

    # Patch post() on this handler instance so the call returns the fake response
    with patch.object(http_client, "post", return_value=fake_response) as mock_post:
        litellm.rerank(client=http_client, **rerank_kwargs)

        mock_post.assert_called_once()
        print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
        requested_url = mock_post.call_args.kwargs["url"]
        assert "us-west-2" in requested_url
        assert "us-east-1" not in requested_url
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue