From 8f43d8d43678d77113ad82bec3adbe996b1cd10f Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 25 Nov 2024 11:09:47 -0800
Subject: [PATCH] add e2e tests for streaming vertex JS with tags

---
 tests/pass_through_tests/test_vertex_ai.py    |  67 ------
 .../test_vertex_with_spend.test.js            | 207 ++++++++++++++++++
 2 files changed, 207 insertions(+), 67 deletions(-)
 create mode 100644 tests/pass_through_tests/test_vertex_with_spend.test.js

diff --git a/tests/pass_through_tests/test_vertex_ai.py b/tests/pass_through_tests/test_vertex_ai.py
index 0b78043f1..23faf6b82 100644
--- a/tests/pass_through_tests/test_vertex_ai.py
+++ b/tests/pass_through_tests/test_vertex_ai.py
@@ -120,73 +120,6 @@ async def test_basic_vertex_ai_pass_through_with_spendlog():
     pass
 
 
-@pytest.mark.asyncio
-async def test_vertex_ai_direct_api_with_tags():
-    """
-    e2e test that tags are added to the spend log
-
-    This is how vertex JS SDK interacts with the pass through endpoint
-
-    - Vertex JS SDK, Auth is sent with `x-litellm-api-key` header (JS SDK uses `Authorization` header, so need to send litellm api key as `x-litellm-api-key` header)
-    - Tags are sent with `tags` header
-    """
-    import requests
-    import json
-
-    url = "http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent"
-
-    headers = {
-        "Content-Type": "application/json",
-        "x-litellm-api-key": "sk-1234",
-        "tags": "vertex-js-sdk,pass-through-endpoint",
-    }
-
-    payload = {
-        "contents": [
-            {"role": "user", "parts": [{"text": "Say 'hello test' and nothing else"}]}
-        ]
-    }
-
-    # Make the request
-    response = requests.post(url, headers=headers, json=payload)
-    assert (
-        response.status_code == 200
-    ), f"Expected 200 status code, got {response.status_code}"
-
-    # Get the litellm call ID from response headers
-    litellm_call_id = response.headers.get("x-litellm-call-id")
-    print(f"LiteLLM Call ID: {litellm_call_id}")
-
-    # Wait for spend to be logged
-    await asyncio.sleep(15)
-
-    # Check spend logs for this specific request
-    spend_response = requests.get(
-        f"http://0.0.0.0:4000/spend/logs?request_id={litellm_call_id}",
-        headers={"Authorization": "Bearer sk-1234"},
-    )
-
-    spend_data = spend_response.json()
-    print(f"Spend data: {spend_data}")
-    assert spend_data is not None, "Should have spend data for the request"
-
-    log_entry = spend_data[0]  # Get the first log entry
-
-    # Verify the response and metadata
-    assert log_entry["request_id"] == litellm_call_id, "Request ID should match"
-    assert (
-        log_entry["call_type"] == "pass_through_endpoint"
-    ), "Call type should be pass_through_endpoint"
-    assert log_entry["request_tags"] == [
-        "vertex-js-sdk",
-        "pass-through-endpoint",
-    ], "Tags should match input"
-    assert (
-        "user_api_key" in log_entry["metadata"]
-    ), "Should have user API key in metadata"
-    assert "gemini" in log_entry["model"], "Model should be gemini"
-
-
 @pytest.mark.asyncio()
 async def test_basic_vertex_ai_pass_through_streaming_with_spendlog():
 
diff --git a/tests/pass_through_tests/test_vertex_with_spend.test.js b/tests/pass_through_tests/test_vertex_with_spend.test.js
new file mode 100644
index 000000000..8a5b91557
--- /dev/null
+++ b/tests/pass_through_tests/test_vertex_with_spend.test.js
@@ -0,0 +1,207 @@
+const { VertexAI, RequestOptions } = require('@google-cloud/vertexai');
+const fs = require('fs');
+const path = require('path');
+const os = require('os');
+const { writeFileSync } = require('fs');
+
+// LiteLLM proxy under test, the key used against it, and the tags every request sends.
+const PROXY_HOST = '127.0.0.1:4000';
+const LITELLM_API_KEY = 'sk-1234';
+const REQUEST_TAGS = ['vertex-js-sdk', 'pass-through-endpoint'];
+// Delay before querying spend logs, giving the proxy time to write the entry.
+const SPEND_LOG_WAIT_MS = 15000;
+
+// Import fetch if the SDK uses it
+const originalFetch = global.fetch || require('node-fetch');
+
+// `x-litellm-call-id` from the most recent response that carried one.
+let lastCallId;
+
+/**
+ * Replacement for `global.fetch` used by everything in this file.
+ *
+ * - Rewrites `https://<proxy>` URLs to `http://`, since the proxy is reached
+ *   over plain HTTP in these tests.
+ * - Remembers the `x-litellm-call-id` response header in `lastCallId` so a test
+ *   can look up the spend log for the request it just made.
+ *
+ * @param {string|URL|Request} url - Request target, as accepted by fetch.
+ * @param {object} [options] - Passed to the underlying fetch unchanged.
+ * @returns {Promise<Response>} The untouched response from the underlying fetch.
+ */
+global.fetch = async function patchedFetch(url, options) {
+    let target = url;
+    // fetch() accepts more than strings; only string/URL inputs are rewritten,
+    // anything else (e.g. a Request) is forwarded as-is instead of throwing.
+    if (typeof url === 'string' || url instanceof URL) {
+        const href = String(url);
+        if (href.startsWith(`https://${PROXY_HOST}`)) {
+            target = href.replace('https://', 'http://');
+        }
+    }
+    console.log('Patched fetch sending request to:', target);
+
+    const response = await originalFetch(target, options);
+
+    // Store the call ID only if the header exists, so a response without it
+    // (e.g. the spend-log lookup) cannot overwrite a captured ID with null.
+    const callId = response.headers.get('x-litellm-call-id');
+    if (callId) {
+        lastCallId = callId;
+    }
+
+    return response;
+};
+
+/**
+ * Writes a service-account key file and points GOOGLE_APPLICATION_CREDENTIALS
+ * at it.
+ *
+ * The key starts from `vertex_key.json` next to this file (when it can be read
+ * and parsed), then `private_key_id` and `private_key` are overwritten from
+ * VERTEX_AI_PRIVATE_KEY_ID / VERTEX_AI_PRIVATE_KEY (empty string when unset).
+ */
+function loadVertexAiCredentials() {
+    console.log("loading vertex ai credentials");
+    const vertexKeyPath = path.join(__dirname, "vertex_key.json");
+
+    let serviceAccountKeyData = {};
+    try {
+        const content = fs.readFileSync(vertexKeyPath, 'utf8');
+        if (content && content.trim()) {
+            serviceAccountKeyData = JSON.parse(content);
+        }
+    } catch (error) {
+        // Best effort: a missing or invalid file is tolerated, but reported
+        // rather than silently ignored.
+        console.warn(`Could not load ${vertexKeyPath}; using an empty key:`, error.message);
+    }
+
+    serviceAccountKeyData.private_key_id = process.env.VERTEX_AI_PRIVATE_KEY_ID || "";
+    // Convert literal "\n" sequences in the env value into real newlines.
+    serviceAccountKeyData.private_key = (process.env.VERTEX_AI_PRIVATE_KEY || "").replace(/\\n/g, "\n");
+
+    // The file contains a private key, so make it readable by the owner only.
+    const tempFilePath = path.join(os.tmpdir(), `vertex-credentials-${Date.now()}.json`);
+    writeFileSync(tempFilePath, JSON.stringify(serviceAccountKeyData, null, 2), { mode: 0o600 });
+
+    process.env.GOOGLE_APPLICATION_CREDENTIALS = tempFilePath;
+}
+
+// Run credential loading before tests (currently disabled)
+// beforeAll(() => {
+//     loadVertexAiCredentials();
+// });
+
+/**
+ * Builds a Gemini model client that goes through the proxy's `/vertex_ai`
+ * route and attaches the LiteLLM key and tags to every request.
+ *
+ * The key is sent as `x-litellm-api-key` because the Vertex JS SDK uses the
+ * `Authorization` header itself.
+ */
+function createTaggedModel() {
+    const vertexAI = new VertexAI({
+        project: 'adroit-crow-413218',
+        location: 'us-central1',
+        apiEndpoint: `${PROXY_HOST}/vertex_ai`
+    });
+
+    const customHeaders = new Headers({
+        "x-litellm-api-key": LITELLM_API_KEY,
+        "tags": REQUEST_TAGS.join(',')
+    });
+
+    return vertexAI.getGenerativeModel(
+        { model: 'gemini-1.0-pro' },
+        { customHeaders }
+    );
+}
+
+/** Returns a fresh single-turn request body shared by both tests. */
+function buildRequest() {
+    return {
+        contents: [{role: 'user', parts: [{text: 'Say "hello test" and nothing else'}]}]
+    };
+}
+
+/**
+ * Waits for the spend log of `callId` to be written, fetches it from the
+ * proxy and asserts its call type, tags, key metadata, model and spend.
+ *
+ * @param {string} callId - `x-litellm-call-id` of the request under test.
+ */
+async function expectSpendLogged(callId) {
+    // Fail here with a clear message instead of querying `request_id=undefined`.
+    expect(callId).toBeTruthy();
+
+    // Wait for spend to be logged
+    await new Promise(resolve => setTimeout(resolve, SPEND_LOG_WAIT_MS));
+
+    const spendResponse = await fetch(
+        `http://${PROXY_HOST}/spend/logs?request_id=${encodeURIComponent(callId)}`,
+        {
+            headers: {
+                'Authorization': `Bearer ${LITELLM_API_KEY}`
+            }
+        }
+    );
+    expect(spendResponse.status).toBe(200);
+
+    const spendData = await spendResponse.json();
+    console.log("spendData", spendData);
+    // An empty or non-array payload would otherwise surface as a TypeError below.
+    expect(Array.isArray(spendData)).toBe(true);
+    expect(spendData.length).toBeGreaterThan(0);
+
+    const [logEntry] = spendData;
+    expect(logEntry.request_id).toBe(callId);
+    expect(logEntry.call_type).toBe('pass_through_endpoint');
+    expect(logEntry.request_tags).toEqual(REQUEST_TAGS);
+    expect(logEntry.metadata).toHaveProperty('user_api_key');
+    expect(logEntry.model).toContain('gemini');
+    expect(logEntry.spend).toBeGreaterThan(0);
+}
+
+describe('Vertex AI Tests', () => {
+    // Clear the captured ID so one left over from an earlier request can never
+    // satisfy this test's spend-log assertions.
+    beforeEach(() => {
+        lastCallId = undefined;
+    });
+
+    test('should successfully generate non-streaming content with tags', async () => {
+        const generativeModel = createTaggedModel();
+
+        const result = await generativeModel.generateContent(buildRequest());
+        expect(result).toBeDefined();
+
+        // Use the captured callId
+        const callId = lastCallId;
+        console.log("Captured Call ID:", callId);
+
+        await expectSpendLogged(callId);
+    }, 25000);
+
+    test('should successfully generate streaming content with tags', async () => {
+        const generativeModel = createTaggedModel();
+
+        const streamingResult = await generativeModel.generateContentStream(buildRequest());
+        expect(streamingResult).toBeDefined();
+
+        for await (const item of streamingResult.stream) {
+            console.log('stream chunk:', JSON.stringify(item));
+            expect(item).toBeDefined();
+        }
+
+        const aggregatedResponse = await streamingResult.response;
+        console.log('aggregated response:', JSON.stringify(aggregatedResponse));
+        expect(aggregatedResponse).toBeDefined();
+
+        // Use the captured callId
+        const callId = lastCallId;
+        console.log("Captured Call ID:", callId);
+
+        await expectSpendLogged(callId);
+    }, 25000);
+});