refactor location of proxy

This commit is contained in:
Ishaan Jaff 2025-04-23 14:38:44 -07:00
parent baa5564f95
commit ce58c53ff1
413 changed files with 2087 additions and 2088 deletions

View file

@ -15,7 +15,7 @@ from unittest.mock import Mock
import httpx
from litellm.proxy.proxy_server import initialize_pass_through_endpoints
from litellm_proxy.proxy_server import initialize_pass_through_endpoints
# Mock the async_client used in the pass_through_request function
@ -36,7 +36,7 @@ def remove_rerank_route(app):
@pytest.fixture
def client():
from litellm.proxy.proxy_server import app
from litellm_proxy.proxy_server import app
remove_rerank_route(
app=app
@ -61,10 +61,10 @@ async def test_pass_through_endpoint_no_headers(client, monkeypatch):
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: dict = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
getattr(litellm_proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm_proxy.proxy_server, "general_settings", general_settings)
# Make a request to the pass-through endpoint
response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
@ -92,10 +92,10 @@ async def test_pass_through_endpoint(client, monkeypatch):
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
getattr(litellm_proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm_proxy.proxy_server, "general_settings", general_settings)
# Make a request to the pass-through endpoint
response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
@ -122,10 +122,10 @@ async def test_pass_through_endpoint_rerank(client):
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
getattr(litellm_proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm_proxy.proxy_server, "general_settings", general_settings)
_json_data = {
"model": "rerank-english-v3.0",
@ -154,8 +154,8 @@ async def test_pass_through_endpoint_rpm_limit(
client, auth, expected_error_code, rpm_limit
):
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
from litellm_proxy._types import UserAPIKeyAuth
from litellm_proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
mock_api_key = "sk-my-test-key"
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
@ -167,10 +167,10 @@ async def test_pass_through_endpoint_rpm_limit(
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
proxy_logging_obj._init_litellm_callbacks()
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
setattr(litellm_proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm_proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm_proxy.proxy_server, "prisma_client", "FAKE-VAR")
setattr(litellm_proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
# Define a pass-through endpoint
pass_through_endpoints = [
@ -185,10 +185,10 @@ async def test_pass_through_endpoint_rpm_limit(
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
getattr(litellm_proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm_proxy.proxy_server, "general_settings", general_settings)
_json_data = {
"model": "rerank-english-v3.0",
@ -220,22 +220,22 @@ async def test_pass_through_endpoint_rpm_limit(
async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
auth, expected_error_code, rpm_limit
):
from litellm.proxy.proxy_server import app
from litellm_proxy.proxy_server import app
client = TestClient(app)
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
from litellm_proxy._types import UserAPIKeyAuth
from litellm_proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
# Store original values
original_user_api_key_cache = getattr(
litellm.proxy.proxy_server, "user_api_key_cache", None
litellm_proxy.proxy_server, "user_api_key_cache", None
)
original_master_key = getattr(litellm.proxy.proxy_server, "master_key", None)
original_prisma_client = getattr(litellm.proxy.proxy_server, "prisma_client", None)
original_master_key = getattr(litellm_proxy.proxy_server, "master_key", None)
original_prisma_client = getattr(litellm_proxy.proxy_server, "prisma_client", None)
original_proxy_logging_obj = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj", None
litellm_proxy.proxy_server, "proxy_logging_obj", None
)
try:
@ -252,10 +252,10 @@ async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
proxy_logging_obj._init_litellm_callbacks()
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
setattr(litellm_proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm_proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm_proxy.proxy_server, "prisma_client", "FAKE-VAR")
setattr(litellm_proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
# Define a pass-through endpoint
pass_through_endpoints = [
@ -274,11 +274,11 @@ async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
getattr(litellm_proxy.proxy_server, "general_settings", {}) or {}
)
old_general_settings = general_settings
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm_proxy.proxy_server, "general_settings", general_settings)
_json_data = {
"batch": [
@ -316,18 +316,18 @@ async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
# Assert the response
assert response.status_code == expected_error_code
setattr(litellm.proxy.proxy_server, "general_settings", old_general_settings)
setattr(litellm_proxy.proxy_server, "general_settings", old_general_settings)
finally:
# Reset to original values
setattr(
litellm.proxy.proxy_server,
litellm_proxy.proxy_server,
"user_api_key_cache",
original_user_api_key_cache,
)
setattr(litellm.proxy.proxy_server, "master_key", original_master_key)
setattr(litellm.proxy.proxy_server, "prisma_client", original_prisma_client)
setattr(litellm_proxy.proxy_server, "master_key", original_master_key)
setattr(litellm_proxy.proxy_server, "prisma_client", original_prisma_client)
setattr(
litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj
litellm_proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj
)
@pytest.mark.asyncio
@ -378,10 +378,10 @@ async def test_pass_through_endpoint_bing(client, monkeypatch):
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
getattr(litellm_proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm_proxy.proxy_server, "general_settings", general_settings)
# Make 2 requests thru the pass-through endpoint
client.get("/bing/search?q=bob+barker")