forked from phoenix/litellm-mirror
Merge pull request #2501 from BerriAI/litellm_fix_using_enterprise_docker
(fix) using enterprise folder on litellm Docker
This commit is contained in:
commit
8a886c6e93
2 changed files with 58 additions and 14 deletions
|
@ -141,6 +141,14 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
# import enterprise folder
|
||||||
|
try:
|
||||||
|
# when using litellm cli
|
||||||
|
import litellm.proxy.enterprise as enterprise
|
||||||
|
except:
|
||||||
|
# when using litellm docker image
|
||||||
|
import enterprise # type: ignore
|
||||||
|
|
||||||
ui_link = f"/ui/"
|
ui_link = f"/ui/"
|
||||||
ui_message = (
|
ui_message = (
|
||||||
f"👉 [```LiteLLM Admin Panel on /ui```]({ui_link}). Create, Edit Keys with SSO"
|
f"👉 [```LiteLLM Admin Panel on /ui```]({ui_link}). Create, Edit Keys with SSO"
|
||||||
|
@ -1626,7 +1634,7 @@ class ProxyConfig:
|
||||||
isinstance(callback, str)
|
isinstance(callback, str)
|
||||||
and callback == "llamaguard_moderations"
|
and callback == "llamaguard_moderations"
|
||||||
):
|
):
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.llama_guard import (
|
from enterprise.enterprise_hooks.llama_guard import (
|
||||||
_ENTERPRISE_LlamaGuard,
|
_ENTERPRISE_LlamaGuard,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1636,7 +1644,7 @@ class ProxyConfig:
|
||||||
isinstance(callback, str)
|
isinstance(callback, str)
|
||||||
and callback == "google_text_moderation"
|
and callback == "google_text_moderation"
|
||||||
):
|
):
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.google_text_moderation import (
|
from enterprise.enterprise_hooks.google_text_moderation import (
|
||||||
_ENTERPRISE_GoogleTextModeration,
|
_ENTERPRISE_GoogleTextModeration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1648,7 +1656,7 @@ class ProxyConfig:
|
||||||
isinstance(callback, str)
|
isinstance(callback, str)
|
||||||
and callback == "llmguard_moderations"
|
and callback == "llmguard_moderations"
|
||||||
):
|
):
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.llm_guard import (
|
from enterprise.enterprise_hooks.llm_guard import (
|
||||||
_ENTERPRISE_LLMGuard,
|
_ENTERPRISE_LLMGuard,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1658,7 +1666,7 @@ class ProxyConfig:
|
||||||
isinstance(callback, str)
|
isinstance(callback, str)
|
||||||
and callback == "blocked_user_check"
|
and callback == "blocked_user_check"
|
||||||
):
|
):
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
from enterprise.enterprise_hooks.blocked_user_list import (
|
||||||
_ENTERPRISE_BlockedUserList,
|
_ENTERPRISE_BlockedUserList,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1668,7 +1676,7 @@ class ProxyConfig:
|
||||||
isinstance(callback, str)
|
isinstance(callback, str)
|
||||||
and callback == "banned_keywords"
|
and callback == "banned_keywords"
|
||||||
):
|
):
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
|
from enterprise.enterprise_hooks.banned_keywords import (
|
||||||
_ENTERPRISE_BannedKeywords,
|
_ENTERPRISE_BannedKeywords,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -4119,7 +4127,7 @@ async def view_spend_tags(
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from litellm.proxy.enterprise.utils import get_spend_by_tags
|
from enterprise.utils import get_spend_by_tags
|
||||||
|
|
||||||
global prisma_client
|
global prisma_client
|
||||||
try:
|
try:
|
||||||
|
@ -4528,12 +4536,8 @@ async def global_spend_models(
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
)
|
)
|
||||||
async def global_predict_spend_logs(request: Request):
|
async def global_predict_spend_logs(request: Request):
|
||||||
try:
|
from enterprise.utils import _forecast_daily_cost
|
||||||
# when using litellm package
|
|
||||||
from litellm.proxy.enterprise.utils import _forecast_daily_cost
|
|
||||||
except:
|
|
||||||
# when using litellm docker image
|
|
||||||
from enterprise.utils import _forecast_daily_cost
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
data = data.get("data")
|
data = data.get("data")
|
||||||
return _forecast_daily_cost(data)
|
return _forecast_daily_cost(data)
|
||||||
|
@ -4991,7 +4995,7 @@ async def block_user(data: BlockUsers):
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
from enterprise.enterprise_hooks.blocked_user_list import (
|
||||||
_ENTERPRISE_BlockedUserList,
|
_ENTERPRISE_BlockedUserList,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -5032,7 +5036,7 @@ async def unblock_user(data: BlockUsers):
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
from enterprise.enterprise_hooks.blocked_user_list import (
|
||||||
_ENTERPRISE_BlockedUserList,
|
_ENTERPRISE_BlockedUserList,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -113,6 +113,46 @@ async def test_spend_logs():
|
||||||
await get_spend_logs(session=session, request_id=response["id"])
|
await get_spend_logs(session=session, request_id=response["id"])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_predict_spend_logs(session):
|
||||||
|
url = f"http://0.0.0.0:4000/global/predict/spend/logs"
|
||||||
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
|
data = {
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"date": "2024-03-09",
|
||||||
|
"spend": 200000,
|
||||||
|
"api_key": "f19bdeb945164278fc11c1020d8dfd70465bffd931ed3cb2e1efa6326225b8b7",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
status = response.status
|
||||||
|
response_text = await response.text()
|
||||||
|
|
||||||
|
print(response_text)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if status != 200:
|
||||||
|
raise Exception(f"Request did not return a 200 status code: {status}")
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_predicted_spend_logs():
|
||||||
|
"""
|
||||||
|
- Create key
|
||||||
|
- Make call (makes sure it's in spend logs)
|
||||||
|
- Get request id from logs
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
result = await get_predict_spend_logs(session=session)
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
assert "response" in result
|
||||||
|
assert len(result["response"]) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="High traffic load test, meant to be run locally")
|
@pytest.mark.skip(reason="High traffic load test, meant to be run locally")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_spend_logs_high_traffic():
|
async def test_spend_logs_high_traffic():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue