From b35bfb0302c9b85127362469c93822b12062a11f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 31 Aug 2024 08:22:27 -0700 Subject: [PATCH] fix cost tracking for vertex ai native --- litellm/proxy/proxy_config.yaml | 2 +- litellm/proxy/proxy_server.py | 6 ++-- tests/pass_through_tests/test_vertex_ai.py | 32 ++++++++++++++++++---- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 708f0e27c..f9252a4d5 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -21,4 +21,4 @@ router_settings: general_settings: master_key: sk-1234 - custom_auth: example_config_yaml.custom_auth.user_api_key_auth \ No newline at end of file + custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e0dbcebfc..80e015e2a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -693,10 +693,10 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): def cost_tracking(): global prisma_client, custom_db_client if prisma_client is not None or custom_db_client is not None: - if isinstance(litellm.success_callback, list): + if isinstance(litellm._async_success_callback, list): verbose_proxy_logger.debug("setting litellm success callback to track cost") - if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore - litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore + if (_PROXY_track_cost_callback) not in litellm._async_success_callback: # type: ignore + litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore async def _PROXY_failure_handler( diff --git a/tests/pass_through_tests/test_vertex_ai.py b/tests/pass_through_tests/test_vertex_ai.py index 675557098..542105eb4 100644 --- a/tests/pass_through_tests/test_vertex_ai.py +++ b/tests/pass_through_tests/test_vertex_ai.py @@ -66,12 +66,24 @@ async def call_spend_logs_endpoint(): import requests todays_date = datetime.datetime.now().strftime("%Y-%m-%d") - url = f"http://0.0.0.0:4000/global/spend/logs" + url = f"http://0.0.0.0:4000/global/spend/logs?api_key=best-api-key-ever" headers = {"Authorization": f"Bearer sk-1234"} response = requests.get(url, headers=headers) print("response from call_spend_logs_endpoint", response) - return response + json_response = response.json() + + # get spend for today + """ + json response looks like this + + [{'date': '2024-08-30', 'spend': 0.00016600000000000002, 'api_key': 'best-api-key-ever'}] + """ + + todays_date = datetime.datetime.now().strftime("%Y-%m-%d") + for spend_log in json_response: + if spend_log["date"] == todays_date: + return spend_log["spend"] LITE_LLM_ENDPOINT = "http://localhost:4000" @@ -79,7 +91,10 @@ LITE_LLM_ENDPOINT = "http://localhost:4000" @pytest.mark.asyncio() async def test_basic_vertex_ai_pass_through_with_spendlog(): - load_vertex_ai_credentials() + + spend_before = await call_spend_logs_endpoint() or 0.0 + # load_vertex_ai_credentials() + vertexai.init( project="adroit-crow-413218", location="us-central1", @@ -92,8 +107,13 @@ async def test_basic_vertex_ai_pass_through_with_spendlog(): print("response", response) - await asyncio.sleep(3) - _spend_logs_response = await call_spend_logs_endpoint() - print("spend logs response", _spend_logs_response) + await asyncio.sleep(20) + spend_after = await call_spend_logs_endpoint() + print("spend_after", spend_after) + assert ( + spend_after > spend_before + ), "Spend should be greater than before. spend_before: {}, spend_after: {}".format( + spend_before, spend_after + ) pass