diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b30fced3d..f7a55dd70 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -141,6 +141,14 @@ import json
 import logging
 from typing import Union
 
+# import enterprise folder
+try:
+    # when using litellm cli
+    import litellm.proxy.enterprise as enterprise
+except:
+    # when using litellm docker image
+    import enterprise  # type: ignore
+
 ui_link = f"/ui/"
 ui_message = (
     f"👉 [```LiteLLM Admin Panel on /ui```]({ui_link}). Create, Edit Keys with SSO"
@@ -1626,7 +1634,7 @@ class ProxyConfig:
                             isinstance(callback, str)
                             and callback == "llamaguard_moderations"
                         ):
-                            from litellm.proxy.enterprise.enterprise_hooks.llama_guard import (
+                            from enterprise.enterprise_hooks.llama_guard import (
                                 _ENTERPRISE_LlamaGuard,
                             )
 
@@ -1636,7 +1644,7 @@ class ProxyConfig:
                             isinstance(callback, str)
                             and callback == "google_text_moderation"
                         ):
-                            from litellm.proxy.enterprise.enterprise_hooks.google_text_moderation import (
+                            from enterprise.enterprise_hooks.google_text_moderation import (
                                 _ENTERPRISE_GoogleTextModeration,
                             )
 
@@ -1648,7 +1656,7 @@ class ProxyConfig:
                             isinstance(callback, str)
                             and callback == "llmguard_moderations"
                         ):
-                            from litellm.proxy.enterprise.enterprise_hooks.llm_guard import (
+                            from enterprise.enterprise_hooks.llm_guard import (
                                 _ENTERPRISE_LLMGuard,
                             )
 
@@ -1658,7 +1666,7 @@ class ProxyConfig:
                             isinstance(callback, str)
                             and callback == "blocked_user_check"
                         ):
-                            from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
+                            from enterprise.enterprise_hooks.blocked_user_list import (
                                 _ENTERPRISE_BlockedUserList,
                             )
 
@@ -1668,7 +1676,7 @@ class ProxyConfig:
                             isinstance(callback, str)
                             and callback == "banned_keywords"
                         ):
-                            from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
+                            from enterprise.enterprise_hooks.banned_keywords import (
                                 _ENTERPRISE_BannedKeywords,
                             )
 
@@ -4119,7 +4127,7 @@ async def view_spend_tags(
 
     ```
     """
-    from litellm.proxy.enterprise.utils import get_spend_by_tags
+    from enterprise.utils import get_spend_by_tags
 
     global prisma_client
     try:
@@ -4528,12 +4536,8 @@ async def global_spend_models(
     dependencies=[Depends(user_api_key_auth)],
 )
 async def global_predict_spend_logs(request: Request):
-    try:
-        # when using litellm package
-        from litellm.proxy.enterprise.utils import _forecast_daily_cost
-    except:
-        # when using litellm docker image
-        from enterprise.utils import _forecast_daily_cost
+    from enterprise.utils import _forecast_daily_cost
+
     data = await request.json()
     data = data.get("data")
     return _forecast_daily_cost(data)
@@ -4991,7 +4995,7 @@ async def block_user(data: BlockUsers):
     }'
     ```
     """
-    from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
+    from enterprise.enterprise_hooks.blocked_user_list import (
         _ENTERPRISE_BlockedUserList,
     )
 
@@ -5032,7 +5036,7 @@ async def unblock_user(data: BlockUsers):
     }'
     ```
     """
-    from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
+    from enterprise.enterprise_hooks.blocked_user_list import (
         _ENTERPRISE_BlockedUserList,
     )
 
diff --git a/tests/test_spend_logs.py b/tests/test_spend_logs.py
index 4d7ad175f..c6866317d 100644
--- a/tests/test_spend_logs.py
+++ b/tests/test_spend_logs.py
@@ -113,6 +113,46 @@ async def test_spend_logs():
         await get_spend_logs(session=session, request_id=response["id"])
 
 
+async def get_predict_spend_logs(session):
+    url = f"http://0.0.0.0:4000/global/predict/spend/logs"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {
+        "data": [
+            {
+                "date": "2024-03-09",
+                "spend": 200000,
+                "api_key": "f19bdeb945164278fc11c1020d8dfd70465bffd931ed3cb2e1efa6326225b8b7",
+            }
+        ]
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+        return await response.json()
+
+
+@pytest.mark.asyncio
+async def test_get_predicted_spend_logs():
+    """
+    - Create key
+    - Make call (makes sure it's in spend logs)
+    - Get request id from logs
+    """
+    async with aiohttp.ClientSession() as session:
+        result = await get_predict_spend_logs(session=session)
+        print(result)
+
+        assert "response" in result
+        assert len(result["response"]) > 0
+
+
 @pytest.mark.skip(reason="High traffic load test, meant to be run locally")
 @pytest.mark.asyncio
 async def test_spend_logs_high_traffic():