refactor location of proxy

This commit is contained in:
Ishaan Jaff 2025-04-23 14:38:44 -07:00
parent baa5564f95
commit ce58c53ff1
413 changed files with 2087 additions and 2088 deletions

View file

@@ -3,7 +3,7 @@ import os
import sys
from typing import Any, Dict, Optional, List
from unittest.mock import Mock
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
from litellm_proxy.utils import _get_redoc_url, _get_docs_url
import json
import pytest
from fastapi import Request
@@ -14,9 +14,9 @@ sys.path.insert(
import litellm
from unittest.mock import MagicMock, patch, AsyncMock
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import is_request_body_safe
from litellm.proxy.litellm_pre_call_utils import (
from litellm_proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm_proxy.auth.auth_utils import is_request_body_safe
from litellm_proxy.litellm_pre_call_utils import (
_get_dynamic_logging_metadata,
add_litellm_data_to_request,
)
@@ -29,7 +29,7 @@ def mock_request(monkeypatch):
mock_request.query_params = {} # Set mock query_params to an empty dictionary
mock_request.headers = {"traceparent": "test_traceparent"}
monkeypatch.setattr(
"litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request", mock_request
"litellm_proxy.litellm_pre_call_utils.add_litellm_data_to_request", mock_request
)
return mock_request
@@ -92,7 +92,7 @@ async def test_traceparent_not_added_by_default(endpoint, mock_request):
from litellm.integrations.opentelemetry import OpenTelemetry
otel_logger = OpenTelemetry()
setattr(litellm.proxy.proxy_server, "open_telemetry_logger", otel_logger)
setattr(litellm_proxy.proxy_server, "open_telemetry_logger", otel_logger)
mock_request.url.path = endpoint
user_api_key_dict = UserAPIKeyAuth(
@@ -110,7 +110,7 @@ async def test_traceparent_not_added_by_default(endpoint, mock_request):
_extra_headers = data.get("extra_headers") or {}
assert "traceparent" not in _extra_headers
setattr(litellm.proxy.proxy_server, "open_telemetry_logger", None)
setattr(litellm_proxy.proxy_server, "open_telemetry_logger", None)
@pytest.mark.parametrize(
@@ -232,7 +232,7 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066"
os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620"
os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com"
from litellm.proxy.proxy_server import ProxyConfig
from litellm_proxy.proxy_server import ProxyConfig
proxy_config = ProxyConfig()
user_api_key_dict = UserAPIKeyAuth(
@@ -314,7 +314,7 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
],
)
def test_dynamic_turn_off_message_logging(callback_vars):
from litellm.proxy.proxy_server import ProxyConfig
from litellm_proxy.proxy_server import ProxyConfig
proxy_config = ProxyConfig()
user_api_key_dict = UserAPIKeyAuth(
@@ -460,7 +460,7 @@ def test_is_request_body_safe_model_enabled(
def test_reading_openai_org_id_from_headers():
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm_proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
headers = {
"OpenAI-Organization": "test_org_id",
@@ -488,8 +488,8 @@ def test_reading_openai_org_id_from_headers():
)
def test_add_litellm_data_for_backend_llm_call(headers, expected_data):
import json
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm_proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm_proxy._types import UserAPIKeyAuth
user_api_key_dict = UserAPIKeyAuth(
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
@@ -509,8 +509,8 @@ def test_foward_litellm_user_info_to_backend_llm_call():
litellm.add_user_information_to_llm_headers = True
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm_proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm_proxy._types import UserAPIKeyAuth
user_api_key_dict = UserAPIKeyAuth(
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
@@ -531,10 +531,10 @@ def test_foward_litellm_user_info_to_backend_llm_call():
def test_update_internal_user_params():
from litellm.proxy.management_endpoints.internal_user_endpoints import (
from litellm_proxy.management_endpoints.internal_user_endpoints import (
_update_internal_new_user_params,
)
from litellm.proxy._types import NewUserRequest
from litellm_proxy._types import NewUserRequest
litellm.default_internal_user_params = {
"max_budget": 100,
@@ -558,7 +558,7 @@ def test_update_internal_user_params():
@pytest.mark.asyncio
async def test_proxy_config_update_from_db():
from litellm.proxy.proxy_server import ProxyConfig
from litellm_proxy.proxy_server import ProxyConfig
from pydantic import BaseModel
proxy_config = ProxyConfig()
@@ -602,10 +602,10 @@ async def test_proxy_config_update_from_db():
def test_prepare_key_update_data():
from litellm.proxy.management_endpoints.key_management_endpoints import (
from litellm_proxy.management_endpoints.key_management_endpoints import (
prepare_key_update_data,
)
from litellm.proxy._types import UpdateKeyRequest
from litellm_proxy._types import UpdateKeyRequest
existing_key_row = MagicMock()
data = UpdateKeyRequest(key="test_key", models=["gpt-4"], duration="120s")
@@ -691,7 +691,7 @@ def test_get_docs_url(env_vars, expected_url):
],
)
def test_merge_tags(request_tags, tags_to_add, expected_tags):
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm_proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
result = LiteLLMProxyRequestSetup._merge_tags(
request_tags=request_tags, tags_to_add=tags_to_add
@@ -855,7 +855,7 @@ async def test_add_litellm_data_to_request_duplicate_tags(
def test_enforced_params_check(
general_settings, user_api_key_dict, request_body, expected_error
):
from litellm.proxy.litellm_pre_call_utils import _enforced_params_check
from litellm_proxy.litellm_pre_call_utils import _enforced_params_check
if expected_error:
with pytest.raises(ValueError):
@@ -875,7 +875,7 @@ def test_enforced_params_check(
def test_get_key_models():
from litellm.proxy.auth.model_checks import get_key_models
from litellm_proxy.auth.model_checks import get_key_models
from collections import defaultdict
user_api_key_dict = UserAPIKeyAuth(
@@ -899,7 +899,7 @@ def test_get_key_models():
def test_get_team_models():
from litellm.proxy.auth.model_checks import get_team_models
from litellm_proxy.auth.model_checks import get_team_models
from collections import defaultdict
user_api_key_dict = UserAPIKeyAuth(
@@ -925,7 +925,7 @@ def test_get_team_models():
def test_update_config_fields():
from litellm.proxy.proxy_server import ProxyConfig
from litellm_proxy.proxy_server import ProxyConfig
proxy_config = ProxyConfig()
@@ -979,7 +979,7 @@ def test_get_complete_model_list(proxy_model_list, provider):
"""
Test that get_complete_model_list correctly expands model groups like 'openai/*' into individual models with provider prefixes
"""
from litellm.proxy.auth.model_checks import get_complete_model_list
from litellm_proxy.auth.model_checks import get_complete_model_list
complete_list = get_complete_model_list(
proxy_model_list=proxy_model_list,
@@ -999,7 +999,7 @@ def test_get_complete_model_list(proxy_model_list, provider):
def test_team_callback_metadata_all_none_values():
from litellm.proxy._types import TeamCallbackMetadata
from litellm_proxy._types import TeamCallbackMetadata
resp = TeamCallbackMetadata(
success_callback=None,
@@ -1021,7 +1021,7 @@ def test_team_callback_metadata_all_none_values():
],
)
def test_team_callback_metadata_none_values(none_key):
from litellm.proxy._types import TeamCallbackMetadata
from litellm_proxy._types import TeamCallbackMetadata
if none_key == "success_callback":
args = {
@@ -1055,8 +1055,8 @@ def test_proxy_config_state_post_init_callback_call():
Where team_id was being popped from config, after callback was called
"""
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy.proxy_server import ProxyConfig
from litellm_proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm_proxy.proxy_server import ProxyConfig
pc = ProxyConfig()
@@ -1088,7 +1088,7 @@ def test_proxy_config_state_get_config_state_error():
"""
Ensures that get_config_state does not raise an error when the config is not a valid dictionary
"""
from litellm.proxy.proxy_server import ProxyConfig
from litellm_proxy.proxy_server import ProxyConfig
import threading
test_config = {
@@ -1142,7 +1142,7 @@ def test_litellm_verification_token_view_response_with_budget_table(
expected_user_api_key_auth_key,
expected_user_api_key_auth_value,
):
from litellm.proxy._types import LiteLLM_VerificationTokenView
from litellm_proxy._types import LiteLLM_VerificationTokenView
args: Dict[str, Any] = {
"token": "78b627d4d14bc3acf5571ae9cb6834e661bc8794d1209318677387add7621ce1",
@ -1194,8 +1194,8 @@ def test_litellm_verification_token_view_response_with_budget_table(
def test_is_allowed_to_make_key_request():
from litellm.proxy._types import LitellmUserRoles
from litellm.proxy.management_endpoints.key_management_endpoints import (
from litellm_proxy._types import LitellmUserRoles
from litellm_proxy.management_endpoints.key_management_endpoints import (
_is_allowed_to_make_key_request,
)
@ -1225,7 +1225,7 @@ def test_is_allowed_to_make_key_request():
def test_get_model_group_info():
from litellm.proxy.proxy_server import _get_model_group_info
from litellm_proxy.proxy_server import _get_model_group_info
from litellm import Router
router = Router(
@@ -1310,14 +1310,14 @@ class MockPrismaClientDB:
@pytest.mark.asyncio
async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
# Patch the prisma_client import
from litellm.proxy._types import UserInfoResponse
from litellm_proxy._types import UserInfoResponse
with patch(
"litellm.proxy.proxy_server.prisma_client",
"litellm_proxy.proxy_server.prisma_client",
MockPrismaClientDB(mock_team_data, mock_key_data),
):
from litellm.proxy.management_endpoints.internal_user_endpoints import (
from litellm_proxy.management_endpoints.internal_user_endpoints import (
_get_user_info_for_proxy_admin,
)
@@ -1330,9 +1330,9 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
def test_custom_openid_response():
from litellm.proxy.management_endpoints.ui_sso import generic_response_convertor
from litellm.proxy.management_endpoints.ui_sso import JWTHandler
from litellm.proxy._types import LiteLLM_JWTAuth
from litellm_proxy.management_endpoints.ui_sso import generic_response_convertor
from litellm_proxy.management_endpoints.ui_sso import JWTHandler
from litellm_proxy._types import LiteLLM_JWTAuth
from litellm.caching import DualCache
jwt_handler = JWTHandler()
@@ -1365,7 +1365,7 @@ def test_update_key_request_validation():
"""
Ensures that the UpdateKeyRequest model validates the temp_budget_increase and temp_budget_expiry fields together
"""
from litellm.proxy._types import UpdateKeyRequest
from litellm_proxy._types import UpdateKeyRequest
with pytest.raises(Exception):
UpdateKeyRequest(
@@ -1387,8 +1387,8 @@ def test_update_key_request_validation():
def test_get_temp_budget_increase():
from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase
from litellm.proxy._types import UserAPIKeyAuth
from litellm_proxy.auth.user_api_key_auth import _get_temp_budget_increase
from litellm_proxy._types import UserAPIKeyAuth
from datetime import datetime, timedelta
expiry = datetime.now() + timedelta(days=1)
@@ -1406,10 +1406,10 @@ def test_get_temp_budget_increase():
def test_update_key_budget_with_temp_budget_increase():
from litellm.proxy.auth.user_api_key_auth import (
from litellm_proxy.auth.user_api_key_auth import (
_update_key_budget_with_temp_budget_increase,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm_proxy._types import UserAPIKeyAuth
from datetime import datetime, timedelta
expiry = datetime.now() + timedelta(days=1)
@@ -1431,7 +1431,7 @@ from unittest.mock import MagicMock, AsyncMock
@pytest.mark.asyncio
async def test_health_check_not_called_when_disabled(monkeypatch):
from litellm.proxy.proxy_server import ProxyStartupEvent
from litellm_proxy.proxy_server import ProxyStartupEvent
# Mock environment variable
monkeypatch.setenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "true")
@@ -1444,7 +1444,7 @@ async def test_health_check_not_called_when_disabled(monkeypatch):
mock_prisma._set_spend_logs_row_count_in_proxy_state = AsyncMock()
# Mock PrismaClient constructor
monkeypatch.setattr(
"litellm.proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma
"litellm_proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma
)
# Call the setup function
@@ -1459,7 +1459,7 @@ async def test_health_check_not_called_when_disabled(monkeypatch):
@patch(
"litellm.proxy.proxy_server.get_openapi_schema",
"litellm_proxy.proxy_server.get_openapi_schema",
return_value={
"paths": {
"/new/route": {"get": {"summary": "New"}},
@@ -1467,8 +1467,8 @@ async def test_health_check_not_called_when_disabled(monkeypatch):
},
)
def test_custom_openapi(mock_get_openapi_schema):
from litellm.proxy.proxy_server import custom_openapi
from litellm.proxy.proxy_server import app
from litellm_proxy.proxy_server import custom_openapi
from litellm_proxy.proxy_server import app
openapi_schema = custom_openapi()
assert openapi_schema is not None
@@ -1478,7 +1478,7 @@ import pytest
from unittest.mock import MagicMock, AsyncMock
import asyncio
from datetime import timedelta
from litellm.proxy.utils import ProxyUpdateSpend
from litellm_proxy.utils import ProxyUpdateSpend
@pytest.mark.asyncio
@@ -1529,7 +1529,7 @@ async def test_spend_logs_cleanup_after_error():
def test_provider_specific_header():
from litellm.proxy.litellm_pre_call_utils import (
from litellm_proxy.litellm_pre_call_utils import (
add_provider_specific_headers_to_request,
)
@@ -1593,7 +1593,7 @@ def test_provider_specific_header():
}
from litellm.proxy._types import LiteLLM_UserTable
from litellm_proxy._types import LiteLLM_UserTable
@pytest.mark.parametrize(
@@ -1610,7 +1610,7 @@ from litellm.proxy._types import LiteLLM_UserTable
],
)
def test_get_known_models_from_wildcard(wildcard_model, expected_models):
from litellm.proxy.auth.model_checks import get_known_models_from_wildcard
from litellm_proxy.auth.model_checks import get_known_models_from_wildcard
wildcard_models = get_known_models_from_wildcard(wildcard_model=wildcard_model)
# Check if all expected models are in the returned list
@@ -1658,7 +1658,7 @@ def test_get_known_models_from_wildcard(wildcard_model, expected_models):
],
)
def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_model):
from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists
from litellm_proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists
# Make a copy of the input data to avoid modifying the test parameters
test_data = data.copy()
@@ -1767,7 +1767,7 @@ async def test_get_admin_team_ids(
should_query_db: bool,
mock_prisma_client,
):
from litellm.proxy.management_endpoints.key_management_endpoints import (
from litellm_proxy.management_endpoints.key_management_endpoints import (
get_admin_team_ids,
)