From 08db9fdc01d449a1b32e7110147ef93fe25eb12e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 10:40:18 -0800 Subject: [PATCH] add e2e test for vertex pass through with spend tags --- tests/pass_through_tests/test_vertex_ai.py | 74 ++++++++++++++++++++-- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/tests/pass_through_tests/test_vertex_ai.py b/tests/pass_through_tests/test_vertex_ai.py index dee0d59eb..0b78043f1 100644 --- a/tests/pass_through_tests/test_vertex_ai.py +++ b/tests/pass_through_tests/test_vertex_ai.py @@ -99,7 +99,7 @@ async def test_basic_vertex_ai_pass_through_with_spendlog(): vertexai.init( project="adroit-crow-413218", location="us-central1", - api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai", + api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex_ai", api_transport="rest", ) @@ -120,8 +120,74 @@ async def test_basic_vertex_ai_pass_through_with_spendlog(): pass +@pytest.mark.asyncio +async def test_vertex_ai_direct_api_with_tags(): + """ + e2e test that tags are added to the spend log + + This is how vertex JS SDK interacts with the pass through endpoint + + - Vertex JS SDK, Auth is sent with `x-litellm-api-key` header (JS SDK uses `Authorization` header, so need to send litellm api key as `x-litellm-api-key` header) + - Tags are sent with `tags` header + """ + import requests + import json + + url = "http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent" + + headers = { + "Content-Type": "application/json", + "x-litellm-api-key": "sk-1234", + "tags": "vertex-js-sdk,pass-through-endpoint", + } + + payload = { + "contents": [ + {"role": "user", "parts": [{"text": "Say 'hello test' and nothing else"}]} + ] + } + + # Make the request + response = requests.post(url, headers=headers, json=payload) + assert ( + response.status_code == 200 + ), f"Expected 200 status code, got {response.status_code}" + + # Get the litellm call ID from response headers + litellm_call_id = response.headers.get("x-litellm-call-id") + print(f"LiteLLM Call ID: {litellm_call_id}") + + # Wait for spend to be logged + await asyncio.sleep(15) + + # Check spend logs for this specific request + spend_response = requests.get( + f"http://0.0.0.0:4000/spend/logs?request_id={litellm_call_id}", + headers={"Authorization": "Bearer sk-1234"}, + ) + + spend_data = spend_response.json() + print(f"Spend data: {spend_data}") + assert spend_data is not None, "Should have spend data for the request" + + log_entry = spend_data[0] # Get the first log entry + + # Verify the response and metadata + assert log_entry["request_id"] == litellm_call_id, "Request ID should match" + assert ( + log_entry["call_type"] == "pass_through_endpoint" + ), "Call type should be pass_through_endpoint" + assert log_entry["request_tags"] == [ + "vertex-js-sdk", + "pass-through-endpoint", + ], "Tags should match input" + assert ( + "user_api_key" in log_entry["metadata"] + ), "Should have user API key in metadata" + assert "gemini" in log_entry["model"], "Model should be gemini" + + @pytest.mark.asyncio() -@pytest.mark.skip(reason="skip flaky test - vertex pass through streaming is flaky") async def test_basic_vertex_ai_pass_through_streaming_with_spendlog(): spend_before = await call_spend_logs_endpoint() or 0.0 @@ -131,7 +197,7 @@ async def test_basic_vertex_ai_pass_through_streaming_with_spendlog(): vertexai.init( project="adroit-crow-413218", location="us-central1", - api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai", + api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex_ai", api_transport="rest", ) @@ -170,7 +236,7 @@ async def test_vertex_ai_pass_through_endpoint_context_caching(): vertexai.init( project="adroit-crow-413218", location="us-central1", - api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai", + api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex_ai", api_transport="rest", )