forked from phoenix/litellm-mirror
Merge pull request #1761 from BerriAI/litellm_fix_dynamic_callbacks
fix(utils.py): override default success callbacks with dynamic callbacks if set
This commit is contained in:
commit
93fb0134e5
9 changed files with 115 additions and 64 deletions
|
@ -10,7 +10,7 @@ from fastapi import HTTPException
|
||||||
import json, traceback
|
import json, traceback
|
||||||
|
|
||||||
|
|
||||||
class CacheControlCheck(CustomLogger):
|
class _PROXY_CacheControlCheck(CustomLogger):
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -7,7 +7,7 @@ from fastapi import HTTPException
|
||||||
import json, traceback
|
import json, traceback
|
||||||
|
|
||||||
|
|
||||||
class MaxBudgetLimiter(CustomLogger):
|
class _PROXY_MaxBudgetLimiter(CustomLogger):
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -9,7 +9,7 @@ from litellm import ModelResponse
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class MaxParallelRequestsHandler(CustomLogger):
|
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
user_api_key_cache = None
|
user_api_key_cache = None
|
||||||
|
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
|
|
|
@ -566,7 +566,6 @@ async def user_api_key_auth(
|
||||||
and (not general_settings.get("allow_user_auth", False))
|
and (not general_settings.get("allow_user_auth", False))
|
||||||
):
|
):
|
||||||
# enters this block when allow_user_auth is set to False
|
# enters this block when allow_user_auth is set to False
|
||||||
assert not general_settings.get("allow_user_auth", False)
|
|
||||||
if route == "/key/info":
|
if route == "/key/info":
|
||||||
# check if user can access this route
|
# check if user can access this route
|
||||||
query_params = request.query_params
|
query_params = request.query_params
|
||||||
|
@ -679,16 +678,17 @@ def cost_tracking():
|
||||||
if prisma_client is not None or custom_db_client is not None:
|
if prisma_client is not None or custom_db_client is not None:
|
||||||
if isinstance(litellm.success_callback, list):
|
if isinstance(litellm.success_callback, list):
|
||||||
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
||||||
if (track_cost_callback) not in litellm.success_callback: # type: ignore
|
if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore
|
||||||
litellm.success_callback.append(track_cost_callback) # type: ignore
|
litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def track_cost_callback(
|
async def _PROXY_track_cost_callback(
|
||||||
kwargs, # kwargs to completion
|
kwargs, # kwargs to completion
|
||||||
completion_response: litellm.ModelResponse, # response from completion
|
completion_response: litellm.ModelResponse, # response from completion
|
||||||
start_time=None,
|
start_time=None,
|
||||||
end_time=None, # start/end time for completion
|
end_time=None, # start/end time for completion
|
||||||
):
|
):
|
||||||
|
verbose_proxy_logger.debug(f"INSIDE _PROXY_track_cost_callback")
|
||||||
global prisma_client, custom_db_client
|
global prisma_client, custom_db_client
|
||||||
try:
|
try:
|
||||||
# check if it has collected an entire stream response
|
# check if it has collected an entire stream response
|
||||||
|
@ -752,8 +752,8 @@ async def update_database(
|
||||||
end_time=None,
|
end_time=None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.info(
|
||||||
f"Enters prisma db call, token: {token}; user_id: {user_id}"
|
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
### UPDATE USER SPEND ###
|
### UPDATE USER SPEND ###
|
||||||
|
@ -865,18 +865,16 @@ async def update_database(
|
||||||
)
|
)
|
||||||
|
|
||||||
payload["spend"] = response_cost
|
payload["spend"] = response_cost
|
||||||
|
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
await prisma_client.insert_data(data=payload, table_name="spend")
|
await prisma_client.insert_data(data=payload, table_name="spend")
|
||||||
|
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
await custom_db_client.insert_data(payload, table_name="spend")
|
await custom_db_client.insert_data(payload, table_name="spend")
|
||||||
|
|
||||||
tasks = []
|
asyncio.create_task(_update_user_db())
|
||||||
tasks.append(_update_user_db())
|
asyncio.create_task(_update_key_db())
|
||||||
tasks.append(_update_key_db())
|
asyncio.create_task(_insert_spend_log_to_db())
|
||||||
tasks.append(_insert_spend_log_to_db())
|
verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Error updating Prisma database: {traceback.format_exc()}"
|
f"Error updating Prisma database: {traceback.format_exc()}"
|
||||||
|
@ -3934,7 +3932,7 @@ def _has_user_setup_sso():
|
||||||
"""
|
"""
|
||||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
ui_username = os.getenv("UI_USERNAME")
|
ui_username = os.getenv("UI_USERNAME", None)
|
||||||
|
|
||||||
sso_setup = (
|
sso_setup = (
|
||||||
(microsoft_client_id is not None)
|
(microsoft_client_id is not None)
|
||||||
|
|
|
@ -8,9 +8,11 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_SpendLogs,
|
LiteLLM_SpendLogs,
|
||||||
)
|
)
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
from litellm.proxy.hooks.parallel_request_limiter import (
|
||||||
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
_PROXY_MaxParallelRequestsHandler,
|
||||||
from litellm.proxy.hooks.cache_control_check import CacheControlCheck
|
)
|
||||||
|
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||||
|
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.proxy.db.base_client import CustomDB
|
from litellm.proxy.db.base_client import CustomDB
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
@ -41,9 +43,9 @@ class ProxyLogging:
|
||||||
## INITIALIZE LITELLM CALLBACKS ##
|
## INITIALIZE LITELLM CALLBACKS ##
|
||||||
self.call_details: dict = {}
|
self.call_details: dict = {}
|
||||||
self.call_details["user_api_key_cache"] = user_api_key_cache
|
self.call_details["user_api_key_cache"] = user_api_key_cache
|
||||||
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
|
||||||
self.max_budget_limiter = MaxBudgetLimiter()
|
self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
|
||||||
self.cache_control_check = CacheControlCheck()
|
self.cache_control_check = _PROXY_CacheControlCheck()
|
||||||
self.alerting: Optional[List] = None
|
self.alerting: Optional[List] = None
|
||||||
self.alerting_threshold: float = 300 # default to 5 min. threshold
|
self.alerting_threshold: float = 300 # default to 5 min. threshold
|
||||||
pass
|
pass
|
||||||
|
@ -522,7 +524,6 @@ class PrismaClient:
|
||||||
response = await self.db.litellm_verificationtoken.find_many(
|
response = await self.db.litellm_verificationtoken.find_many(
|
||||||
order={"spend": "desc"},
|
order={"spend": "desc"},
|
||||||
)
|
)
|
||||||
print_verbose(f"PrismaClient: response={response}")
|
|
||||||
if response is not None:
|
if response is not None:
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
|
@ -1200,8 +1201,6 @@ async def reset_budget(prisma_client: PrismaClient):
|
||||||
table_name="user", query_type="find_all", reset_at=now
|
table_name="user", query_type="find_all", reset_at=now
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"users_to_reset from get_data: {users_to_reset}")
|
|
||||||
|
|
||||||
if users_to_reset is not None and len(users_to_reset) > 0:
|
if users_to_reset is not None and len(users_to_reset) > 0:
|
||||||
for user in users_to_reset:
|
for user in users_to_reset:
|
||||||
user.spend = 0.0
|
user.spend = 0.0
|
||||||
|
|
|
@ -213,7 +213,9 @@ def test_call_with_user_over_budget(custom_db_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -244,7 +246,7 @@ def test_call_with_user_over_budget(custom_db_client):
|
||||||
},
|
},
|
||||||
completion_response=resp,
|
completion_response=resp,
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -285,7 +287,9 @@ def test_call_with_user_over_budget_stream(custom_db_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -317,7 +321,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
|
||||||
},
|
},
|
||||||
completion_response=ModelResponse(),
|
completion_response=ModelResponse(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -357,7 +361,9 @@ def test_call_with_user_key_budget(custom_db_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -388,7 +394,7 @@ def test_call_with_user_key_budget(custom_db_client):
|
||||||
},
|
},
|
||||||
completion_response=resp,
|
completion_response=resp,
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -429,7 +435,9 @@ def test_call_with_key_over_budget_stream(custom_db_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -461,7 +469,7 @@ def test_call_with_key_over_budget_stream(custom_db_client):
|
||||||
},
|
},
|
||||||
completion_response=ModelResponse(),
|
completion_response=ModelResponse(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
|
@ -242,7 +242,9 @@ def test_call_with_user_over_budget(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -275,7 +277,7 @@ def test_call_with_user_over_budget(prisma_client):
|
||||||
start_time=datetime.now(),
|
start_time=datetime.now(),
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -323,7 +325,9 @@ def test_call_with_proxy_over_budget(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -356,7 +360,7 @@ def test_call_with_proxy_over_budget(prisma_client):
|
||||||
start_time=datetime.now(),
|
start_time=datetime.now(),
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -401,7 +405,9 @@ def test_call_with_user_over_budget_stream(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -435,7 +441,7 @@ def test_call_with_user_over_budget_stream(prisma_client):
|
||||||
start_time=datetime.now(),
|
start_time=datetime.now(),
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -488,7 +494,9 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
|
|
||||||
resp = ModelResponse(
|
resp = ModelResponse(
|
||||||
|
@ -522,7 +530,7 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
|
||||||
start_time=datetime.now(),
|
start_time=datetime.now(),
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -891,7 +899,9 @@ def test_call_with_key_over_budget(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
from litellm.caching import Cache
|
from litellm.caching import Cache
|
||||||
|
|
||||||
|
@ -931,7 +941,7 @@ def test_call_with_key_over_budget(prisma_client):
|
||||||
start_time=datetime.now(),
|
start_time=datetime.now(),
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(4)
|
||||||
# test spend_log was written and we can read it
|
# test spend_log was written and we can read it
|
||||||
spend_logs = await view_spend_logs(request_id=request_id)
|
spend_logs = await view_spend_logs(request_id=request_id)
|
||||||
|
|
||||||
|
@ -955,6 +965,8 @@ def test_call_with_key_over_budget(prisma_client):
|
||||||
|
|
||||||
asyncio.run(test())
|
asyncio.run(test())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# print(f"Error - {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
error_detail = e.message
|
error_detail = e.message
|
||||||
assert "Authentication Error, ExceededTokenBudget:" in error_detail
|
assert "Authentication Error, ExceededTokenBudget:" in error_detail
|
||||||
print(vars(e))
|
print(vars(e))
|
||||||
|
@ -983,7 +995,9 @@ async def test_call_with_key_never_over_budget(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
@ -1022,7 +1036,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
|
||||||
start_time=datetime.now(),
|
start_time=datetime.now(),
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
@ -1058,7 +1072,9 @@ async def test_call_with_key_over_budget_stream(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import track_cost_callback
|
from litellm.proxy.proxy_server import (
|
||||||
|
_PROXY_track_cost_callback as track_cost_callback,
|
||||||
|
)
|
||||||
from litellm import ModelResponse, Choices, Message, Usage
|
from litellm import ModelResponse, Choices, Message, Usage
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
@ -1096,7 +1112,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
|
||||||
start_time=datetime.now(),
|
start_time=datetime.now(),
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
|
@ -18,7 +18,9 @@ from litellm import Router
|
||||||
from litellm.proxy.utils import ProxyLogging
|
from litellm.proxy.utils import ProxyLogging
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
from litellm.proxy.hooks.parallel_request_limiter import (
|
||||||
|
_PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
|
||||||
|
)
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
## On Request received
|
## On Request received
|
||||||
|
|
|
@ -774,14 +774,14 @@ class Logging:
|
||||||
self.streaming_chunks = [] # for generating complete stream response
|
self.streaming_chunks = [] # for generating complete stream response
|
||||||
self.sync_streaming_chunks = [] # for generating complete stream response
|
self.sync_streaming_chunks = [] # for generating complete stream response
|
||||||
self.model_call_details = {}
|
self.model_call_details = {}
|
||||||
self.dynamic_input_callbacks = [] # callbacks set for just that call
|
self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call
|
||||||
self.dynamic_failure_callbacks = [] # callbacks set for just that call
|
self.dynamic_failure_callbacks = [] # [TODO] callbacks set for just that call
|
||||||
self.dynamic_success_callbacks = (
|
self.dynamic_success_callbacks = (
|
||||||
dynamic_success_callbacks or []
|
dynamic_success_callbacks # callbacks set for just that call
|
||||||
) # callbacks set for just that call
|
)
|
||||||
self.dynamic_async_success_callbacks = (
|
self.dynamic_async_success_callbacks = (
|
||||||
dynamic_async_success_callbacks or []
|
dynamic_async_success_callbacks # callbacks set for just that call
|
||||||
) # callbacks set for just that call
|
)
|
||||||
## DYNAMIC LANGFUSE KEYS ##
|
## DYNAMIC LANGFUSE KEYS ##
|
||||||
self.langfuse_public_key = langfuse_public_key
|
self.langfuse_public_key = langfuse_public_key
|
||||||
self.langfuse_secret = langfuse_secret
|
self.langfuse_secret = langfuse_secret
|
||||||
|
@ -1145,7 +1145,19 @@ class Logging:
|
||||||
f"Model={self.model} not found in completion cost map."
|
f"Model={self.model} not found in completion cost map."
|
||||||
)
|
)
|
||||||
self.model_call_details["response_cost"] = None
|
self.model_call_details["response_cost"] = None
|
||||||
callbacks = litellm.success_callback + self.dynamic_success_callbacks
|
if self.dynamic_success_callbacks is not None and isinstance(
|
||||||
|
self.dynamic_success_callbacks, list
|
||||||
|
):
|
||||||
|
callbacks = self.dynamic_success_callbacks
|
||||||
|
## keep the internal functions ##
|
||||||
|
for callback in litellm.success_callback:
|
||||||
|
if (
|
||||||
|
isinstance(callback, CustomLogger)
|
||||||
|
and "_PROXY_" in callback.__class__.__name__
|
||||||
|
):
|
||||||
|
callbacks.append(callback)
|
||||||
|
else:
|
||||||
|
callbacks = litellm.success_callback
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
try:
|
try:
|
||||||
if callback == "lite_debugger":
|
if callback == "lite_debugger":
|
||||||
|
@ -1406,9 +1418,6 @@ class Logging:
|
||||||
"""
|
"""
|
||||||
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
|
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
|
||||||
"""
|
"""
|
||||||
verbose_logger.debug(
|
|
||||||
f"Async success callbacks: {litellm._async_success_callback}"
|
|
||||||
)
|
|
||||||
start_time, end_time, result = self._success_handler_helper_fn(
|
start_time, end_time, result = self._success_handler_helper_fn(
|
||||||
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
|
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
|
||||||
)
|
)
|
||||||
|
@ -1452,9 +1461,22 @@ class Logging:
|
||||||
)
|
)
|
||||||
self.model_call_details["response_cost"] = None
|
self.model_call_details["response_cost"] = None
|
||||||
|
|
||||||
callbacks = (
|
if self.dynamic_async_success_callbacks is not None and isinstance(
|
||||||
litellm._async_success_callback + self.dynamic_async_success_callbacks
|
self.dynamic_async_success_callbacks, list
|
||||||
)
|
):
|
||||||
|
callbacks = self.dynamic_async_success_callbacks
|
||||||
|
## keep the internal functions ##
|
||||||
|
for callback in litellm._async_success_callback:
|
||||||
|
callback_name = ""
|
||||||
|
if isinstance(callback, CustomLogger):
|
||||||
|
callback_name = callback.__class__.__name__
|
||||||
|
if callable(callback):
|
||||||
|
callback_name = callback.__name__
|
||||||
|
if "_PROXY_" in callback_name:
|
||||||
|
callbacks.append(callback)
|
||||||
|
else:
|
||||||
|
callbacks = litellm._async_success_callback
|
||||||
|
verbose_logger.debug(f"Async success callbacks: {callbacks}")
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
try:
|
try:
|
||||||
if callback == "cache" and litellm.cache is not None:
|
if callback == "cache" and litellm.cache is not None:
|
||||||
|
@ -1501,6 +1523,7 @@ class Logging:
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
)
|
)
|
||||||
if callable(callback): # custom logger functions
|
if callable(callback): # custom logger functions
|
||||||
|
print_verbose(f"Making async function logging call")
|
||||||
if self.stream:
|
if self.stream:
|
||||||
if "complete_streaming_response" in self.model_call_details:
|
if "complete_streaming_response" in self.model_call_details:
|
||||||
await customLogger.async_log_event(
|
await customLogger.async_log_event(
|
||||||
|
@ -1958,8 +1981,8 @@ def client(original_function):
|
||||||
for index in reversed(removed_async_items):
|
for index in reversed(removed_async_items):
|
||||||
litellm.failure_callback.pop(index)
|
litellm.failure_callback.pop(index)
|
||||||
### DYNAMIC CALLBACKS ###
|
### DYNAMIC CALLBACKS ###
|
||||||
dynamic_success_callbacks = []
|
dynamic_success_callbacks = None
|
||||||
dynamic_async_success_callbacks = []
|
dynamic_async_success_callbacks = None
|
||||||
if kwargs.get("success_callback", None) is not None and isinstance(
|
if kwargs.get("success_callback", None) is not None and isinstance(
|
||||||
kwargs["success_callback"], list
|
kwargs["success_callback"], list
|
||||||
):
|
):
|
||||||
|
@ -1970,7 +1993,12 @@ def client(original_function):
|
||||||
or callback == "dynamodb"
|
or callback == "dynamodb"
|
||||||
or callback == "s3"
|
or callback == "s3"
|
||||||
):
|
):
|
||||||
dynamic_async_success_callbacks.append(callback)
|
if dynamic_async_success_callbacks is not None and isinstance(
|
||||||
|
dynamic_async_success_callbacks, list
|
||||||
|
):
|
||||||
|
dynamic_async_success_callbacks.append(callback)
|
||||||
|
else:
|
||||||
|
dynamic_async_success_callbacks = [callback]
|
||||||
removed_async_items.append(index)
|
removed_async_items.append(index)
|
||||||
# Pop the async items from success_callback in reverse order to avoid index issues
|
# Pop the async items from success_callback in reverse order to avoid index issues
|
||||||
for index in reversed(removed_async_items):
|
for index in reversed(removed_async_items):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue