fix cost tracking for vertex ai native

This commit is contained in:
Ishaan Jaff 2024-08-31 08:22:27 -07:00
parent 06857d108d
commit b35bfb0302
3 changed files with 30 additions and 10 deletions

View file

@ -21,4 +21,4 @@ router_settings:
general_settings:
master_key: sk-1234
custom_auth: example_config_yaml.custom_auth.user_api_key_auth
custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth

View file

@ -693,10 +693,10 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
def cost_tracking():
global prisma_client, custom_db_client
if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list):
if isinstance(litellm._async_success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore
if (_PROXY_track_cost_callback) not in litellm._async_success_callback: # type: ignore
litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore
async def _PROXY_failure_handler(

View file

@ -66,12 +66,24 @@ async def call_spend_logs_endpoint():
import requests
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
url = f"http://0.0.0.0:4000/global/spend/logs"
url = f"http://0.0.0.0:4000/global/spend/logs?api_key=best-api-key-ever"
headers = {"Authorization": f"Bearer sk-1234"}
response = requests.get(url, headers=headers)
print("response from call_spend_logs_endpoint", response)
return response
json_response = response.json()
# get spend for today
"""
json response looks like this
[{'date': '2024-08-30', 'spend': 0.00016600000000000002, 'api_key': 'best-api-key-ever'}]
"""
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
for spend_log in json_response:
if spend_log["date"] == todays_date:
return spend_log["spend"]
LITE_LLM_ENDPOINT = "http://localhost:4000"
@ -79,7 +91,10 @@ LITE_LLM_ENDPOINT = "http://localhost:4000"
@pytest.mark.asyncio()
async def test_basic_vertex_ai_pass_through_with_spendlog():
load_vertex_ai_credentials()
spend_before = await call_spend_logs_endpoint() or 0.0
# load_vertex_ai_credentials()
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
@ -92,8 +107,13 @@ async def test_basic_vertex_ai_pass_through_with_spendlog():
print("response", response)
await asyncio.sleep(3)
_spend_logs_response = await call_spend_logs_endpoint()
print("spend logs response", _spend_logs_response)
await asyncio.sleep(20)
spend_after = await call_spend_logs_endpoint()
print("spend_after", spend_after)
assert (
spend_after > spend_before
), "Spend should be greater than before. spend_before: {}, spend_after: {}".format(
spend_before, spend_after
)
pass