Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 18:24:20 +00:00
fix(proxy_server.py): get master key from environment, if not set in general settings or general settings not set at all (#9617)
* fix(proxy_server.py): get master key from environment, if not set in general settings or general settings not set at all
* test: mark flaky test
* test(test_proxy_server.py): mock prisma client
* ci: add new github workflow for testing just the mock tests
* fix: fix linting error
* ci(conftest.py): add conftest.py to isolate proxy tests
* build(pyproject.toml): add respx to dev dependencies
* build(pyproject.toml): add prisma to dev dependencies
* test: fix mock prompt management tests to use a mock anthropic key
* ci(test-litellm.yml): parallelize mock testing to make it run faster
* build(pyproject.toml): add hypercorn as dev dep
* build(pyproject.toml): separate proxy vs. core dev dependencies, so non-proxy contributors can run tests locally without installing hypercorn
* ci(test-litellm.yml): pin python version
* test(test_rerank.py): move test - cannot be mocked, requires aws credentials for e2e testing
* ci: add thank you message to ci
* test: add mock env var to test
* test: add autouse to tests
* test: test mock env vars for e2e tests
This commit is contained in:
parent 69e28b92c6
commit 0865e52db3

14 changed files with 479 additions and 284 deletions
.github/workflows/test-litellm.yml (new file, 35 lines, vendored)

@@ -0,0 +1,35 @@
name: LiteLLM Tests

on:
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    timeout-minutes: 5

    steps:
      - uses: actions/checkout@v4

      - name: Thank You Message
        run: |
          echo "### 🙏 Thank you for contributing to LiteLLM!" >> $GITHUB_STEP_SUMMARY
          echo "Your PR is being tested now. We appreciate your help in making LiteLLM better!" >> $GITHUB_STEP_SUMMARY

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'

      - name: Install Poetry
        uses: snok/install-poetry@v1

      - name: Install dependencies
        run: |
          poetry install --with dev,proxy-dev --extras proxy
          poetry run pip install pytest-xdist

      - name: Run tests
        run: |
          poetry run pytest tests/litellm -x -vv -n 4
Makefile (3 lines changed)

@@ -14,6 +14,9 @@ help:
 install-dev:
 	poetry install --with dev
 
+install-proxy-dev:
+	poetry install --with dev,proxy-dev
+
 lint: install-dev
 	poetry run pip install types-requests types-setuptools types-redis types-PyYAML
 	cd litellm && poetry run mypy . --ignore-missing-imports
XAIModelInfo (file path not shown in this view)

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import List, Optional
 
 import httpx
 

@@ -22,7 +22,7 @@ class XAIModelInfo(BaseLLMModelInfo):
 
     def get_models(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None
-    ) -> list[str]:
+    ) -> List[str]:
         api_base = self.get_api_base(api_base)
         api_key = self.get_api_key(api_key)
         if api_base is None or api_key is None:
litellm/proxy/proxy_server.py

@@ -462,6 +462,8 @@ async def proxy_startup_event(app: FastAPI):
     if premium_user is False:
         premium_user = _license_check.is_premium()
 
+    ## CHECK MASTER KEY IN ENVIRONMENT ##
+    master_key = get_secret_str("LITELLM_MASTER_KEY")
     ### LOAD CONFIG ###
     worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
     env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
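The two added lines make LITELLM_MASTER_KEY from the environment the fallback when the config's general_settings does not provide a master key; a key set in config.yaml (including an os.environ/ reference) still wins once the config is loaded. A minimal sketch of that resolution order, mirroring the three cases asserted in the new test_aaaproxy_startup_master_key test further down (resolve_master_key is a hypothetical helper, not the proxy's actual code path):

import os
from typing import Optional


def resolve_master_key(general_settings: Optional[dict]) -> Optional[str]:
    # Hypothetical helper illustrating the precedence this commit establishes:
    # config value > os.environ/ reference in the config > LITELLM_MASTER_KEY fallback.
    configured = (general_settings or {}).get("master_key")
    if configured is None:
        return os.getenv("LITELLM_MASTER_KEY")  # env fallback added by this commit
    if configured.startswith("os.environ/"):
        # "os.environ/CUSTOM_MASTER_KEY" resolves to the named environment variable
        return os.getenv(configured.removeprefix("os.environ/"))
    return configured


# The three test cases map onto this sketch roughly as follows:
os.environ["LITELLM_MASTER_KEY"] = "sk-67890"
os.environ["CUSTOM_MASTER_KEY"] = "sk-resolved-key"
assert resolve_master_key({"master_key": "sk-12345"}) == "sk-12345"
assert resolve_master_key({}) == "sk-67890"
assert resolve_master_key({"master_key": "os.environ/CUSTOM_MASTER_KEY"}) == "sk-resolved-key"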
poetry.lock (generated, 444 lines changed): file diff suppressed because it is too large.
pyproject.toml

@@ -99,6 +99,11 @@ mypy = "^1.0"
 pytest = "^7.4.3"
 pytest-mock = "^3.12.0"
 pytest-asyncio = "^0.21.1"
+respx = "^0.20.2"
+
+[tool.poetry.group.proxy-dev.dependencies]
+prisma = "0.11.0"
+hypercorn = "^0.15.0"
 
 [build-system]
 requires = ["poetry-core", "wheel"]
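respx goes into the core dev group so outbound HTTP calls can be stubbed in the mock tests, while prisma and hypercorn move to a separate proxy-dev group that only proxy contributors need to install. As a rough, hypothetical illustration of what respx enables (not taken from the litellm test suite), a test can intercept an httpx request like this:

import httpx
import respx


@respx.mock
def test_mocked_chat_endpoint():
    # Stub the outbound request so the test never touches a real API.
    route = respx.post("https://api.example.com/v1/chat").mock(
        return_value=httpx.Response(200, json={"ok": True})
    )
    response = httpx.post("https://api.example.com/v1/chat", json={"prompt": "hi"})
    assert route.called
    assert response.json() == {"ok": True}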
tests/litellm/conftest.py (new file, 63 lines)

@@ -0,0 +1,63 @@
# conftest.py

import importlib
import os
import sys

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm


@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
    """
    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
    """
    curr_dir = os.getcwd()  # Get the current working directory
    sys.path.insert(
        0, os.path.abspath("../..")
    )  # Adds the project directory to the system path

    import litellm
    from litellm import Router

    importlib.reload(litellm)

    try:
        if hasattr(litellm, "proxy") and hasattr(litellm.proxy, "proxy_server"):
            import litellm.proxy.proxy_server

            importlib.reload(litellm.proxy.proxy_server)
    except Exception as e:
        print(f"Error reloading litellm.proxy.proxy_server: {e}")

    import asyncio

    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    print(litellm)
    # from litellm import Router, completion, aembedding, acompletion, embedding
    yield

    # Teardown code (executes after the yield point)
    loop.close()  # Close the loop created earlier
    asyncio.set_event_loop(None)  # Remove the reference to the loop


def pytest_collection_modifyitems(config, items):
    # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
    custom_logger_tests = [
        item for item in items if "custom_logger" in item.parent.name
    ]
    other_tests = [item for item in items if "custom_logger" not in item.parent.name]

    # Sort tests based on their names
    custom_logger_tests.sort(key=lambda x: x.name)
    other_tests.sort(key=lambda x: x.name)

    # Reorder the items list
    items[:] = custom_logger_tests + other_tests
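The autouse setup_and_teardown fixture reloads litellm before each test so module-level state such as litellm.callbacks cannot leak between tests. A hypothetical illustration of the leak the reload prevents, assuming litellm.callbacks defaults to an empty list when the module is (re)executed:

import importlib

import litellm

litellm.callbacks = ["langfuse"]  # one test registers a callback on the shared module
importlib.reload(litellm)         # the autouse fixture reloads litellm before the next test
assert litellm.callbacks == []    # so the next test starts from the default, empty state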
Custom prompt management tests (file path not shown in this view)

@@ -19,6 +19,11 @@ from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import StandardCallbackDynamicParams
 
 
+@pytest.fixture(autouse=True)
+def setup_anthropic_api_key(monkeypatch):
+    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-some-key")
+
+
 class TestCustomPromptManagement(CustomPromptManagement):
     def get_chat_completion_prompt(
         self,

@@ -50,7 +55,7 @@ class TestCustomPromptManagement(CustomPromptManagement):
 
 
 @pytest.mark.asyncio
-async def test_custom_prompt_management_with_prompt_id():
+async def test_custom_prompt_management_with_prompt_id(monkeypatch):
     custom_prompt_management = TestCustomPromptManagement()
     litellm.callbacks = [custom_prompt_management]
 
Spend logs / spend tracking tests (file path not shown in this view)

@@ -26,6 +26,11 @@ def client():
     return TestClient(app)
 
 
+@pytest.fixture(autouse=True)
+def add_anthropic_api_key_to_env(monkeypatch):
+    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-1234567890")
+
+
 @pytest.mark.asyncio
 async def test_ui_view_spend_logs_with_user_id(client, monkeypatch):
     # Mock data for the test

@@ -500,7 +505,7 @@ class TestSpendLogsPayload:
         return mock_response
 
     @pytest.mark.asyncio
-    async def test_spend_logs_payload_success_log_with_api_base(self):
+    async def test_spend_logs_payload_success_log_with_api_base(self, monkeypatch):
         from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 
         litellm.callbacks = [_ProxyDBLogger(message_logging=False)]
test_proxy_server.py

@@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, mock_open, patch
 import click
+import httpx
 import pytest
 import yaml
 from fastapi import FastAPI
 from fastapi.testclient import TestClient
 

@@ -74,3 +75,92 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
         call[0] for call in mock_proxy_config.get_credentials.mock_calls
     ]
     assert len(mock_scheduler_calls) > 0
+
+
+# Mock Prisma
+class MockPrisma:
+    def __init__(self, database_url=None, proxy_logging_obj=None, http_client=None):
+        self.database_url = database_url
+        self.proxy_logging_obj = proxy_logging_obj
+        self.http_client = http_client
+
+    async def connect(self):
+        pass
+
+    async def disconnect(self):
+        pass
+
+
+mock_prisma = MockPrisma()
+
+
+@patch(
+    "litellm.proxy.proxy_server.ProxyStartupEvent._setup_prisma_client",
+    return_value=mock_prisma,
+)
+@pytest.mark.asyncio
+async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path):
+    """
+    Test that master_key is correctly loaded from either config.yaml or environment variables
+    """
+    import yaml
+    from fastapi import FastAPI
+
+    # Import happens here - this is when the module probably reads the config path
+    from litellm.proxy.proxy_server import proxy_startup_event
+
+    # Mock the Prisma import
+    monkeypatch.setattr("litellm.proxy.proxy_server.PrismaClient", MockPrisma)
+
+    # Create test app
+    app = FastAPI()
+
+    # Test Case 1: Master key from config.yaml
+    test_master_key = "sk-12345"
+    test_config = {"general_settings": {"master_key": test_master_key}}
+
+    # Create a temporary config file
+    config_path = tmp_path / "config.yaml"
+    with open(config_path, "w") as f:
+        yaml.dump(test_config, f)
+
+    print(f"SET ENV VARIABLE - CONFIG_FILE_PATH, str(config_path): {str(config_path)}")
+    # Second setting of CONFIG_FILE_PATH to a different value
+    monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path))
+    print(f"config_path: {config_path}")
+    print(f"os.getenv('CONFIG_FILE_PATH'): {os.getenv('CONFIG_FILE_PATH')}")
+    async with proxy_startup_event(app):
+        from litellm.proxy.proxy_server import master_key
+
+        assert master_key == test_master_key
+
+    # Test Case 2: Master key from environment variable
+    test_env_master_key = "sk-67890"
+
+    # Create empty config
+    empty_config = {"general_settings": {}}
+    with open(config_path, "w") as f:
+        yaml.dump(empty_config, f)
+
+    monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key)
+    print("test_env_master_key: {}".format(test_env_master_key))
+    async with proxy_startup_event(app):
+        from litellm.proxy.proxy_server import master_key
+
+        assert master_key == test_env_master_key
+
+    # Test Case 3: Master key with os.environ prefix
+    test_resolved_key = "sk-resolved-key"
+    test_config_with_prefix = {
+        "general_settings": {"master_key": "os.environ/CUSTOM_MASTER_KEY"}
+    }
+
+    # Create config with os.environ prefix
+    with open(config_path, "w") as f:
+        yaml.dump(test_config_with_prefix, f)
+
+    monkeypatch.setenv("CUSTOM_MASTER_KEY", test_resolved_key)
+    async with proxy_startup_event(app):
+        from litellm.proxy.proxy_server import master_key
+
+        assert master_key == test_resolved_key
@ -1,51 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm import rerank
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
def test_rerank_infer_region_from_model_arn(monkeypatch):
|
||||
mock_response = MagicMock()
|
||||
|
||||
monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
|
||||
args = {
|
||||
"model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
|
||||
"query": "hello",
|
||||
"documents": ["hello", "world"],
|
||||
}
|
||||
|
||||
def return_val():
|
||||
return {
|
||||
"results": [
|
||||
{"index": 0, "relevanceScore": 0.6716859340667725},
|
||||
{"index": 1, "relevanceScore": 0.0004994205664843321},
|
||||
]
|
||||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
mock_response.headers = {"key": "value"}
|
||||
mock_response.status_code = 200
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post", return_value=mock_response) as mock_post:
|
||||
rerank(
|
||||
model=args["model"],
|
||||
query=args["query"],
|
||||
documents=args["documents"],
|
||||
client=client,
|
||||
)
|
||||
mock_post.assert_called_once()
|
||||
print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
|
||||
assert "us-west-2" in mock_post.call_args.kwargs["url"]
|
||||
assert "us-east-1" not in mock_post.call_args.kwargs["url"]
|
|
Completion unit tests (file path not shown in this view)

@@ -14,6 +14,15 @@ from unittest.mock import MagicMock, patch
 import litellm
 
 
+@pytest.fixture(autouse=True)
+def add_api_keys_to_env(monkeypatch):
+    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-1234567890")
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-api03-1234567890")
+    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "my-fake-aws-access-key-id")
+    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "my-fake-aws-secret-access-key")
+    monkeypatch.setenv("AWS_REGION", "us-east-1")
+
+
 @pytest.fixture
 def openai_api_response():
     mock_response_data = {

@@ -130,7 +139,8 @@ def test_completion_missing_role(openai_api_response):
 )
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_url_with_format_param(model, sync_mode):
+async def test_url_with_format_param(model, sync_mode, monkeypatch):
+
     from litellm import acompletion, completion
     from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 
Cohere chat completion tests (file path not shown in this view)

@@ -22,6 +22,7 @@ litellm.num_retries = 3
 
 
 @pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.flaky(retries=3, delay=1)
 @pytest.mark.asyncio
 async def test_chat_completion_cohere_citations(stream):
     try:
test_rerank.py (e2e rerank tests; destination of the moved bedrock ARN test)

@@ -481,3 +481,42 @@ def test_rerank_cohere_api():
     assert response.results[0]["document"]["text"] is not None
     assert response.results[0]["document"]["text"] == "hello"
     assert response.results[1]["document"]["text"] == "world"
+
+
+def test_rerank_infer_region_from_model_arn(monkeypatch):
+
+    mock_response = MagicMock()
+
+    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
+    args = {
+        "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
+        "query": "hello",
+        "documents": ["hello", "world"],
+    }
+
+    def return_val():
+        return {
+            "results": [
+                {"index": 0, "relevanceScore": 0.6716859340667725},
+                {"index": 1, "relevanceScore": 0.0004994205664843321},
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post", return_value=mock_response) as mock_post:
+        litellm.rerank(
+            model=args["model"],
+            query=args["query"],
+            documents=args["documents"],
+            client=client,
+        )
+
+    mock_post.assert_called_once()
+    print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
+    assert "us-west-2" in mock_post.call_args.kwargs["url"]
+    assert "us-east-1" not in mock_post.call_args.kwargs["url"]