diff --git a/.circleci/config.yml b/.circleci/config.yml index db7c4ef5b..2f9df4d52 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1137,6 +1137,15 @@ jobs: pip install "PyGithub==1.59.1" pip install "google-cloud-aiplatform==1.59.0" pip install anthropic + python -m pip install -r requirements.txt + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/pass_through_unit_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m - run: name: Build Docker image command: docker build -t my-app:latest -f ./docker/Dockerfile.database . @@ -1172,6 +1181,26 @@ jobs: - run: name: Wait for app to be ready command: dockerize -wait http://localhost:4000 -timeout 5m + # New steps to run Node.js test + - run: + name: Install Node.js + command: | + curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - + sudo apt-get install -y nodejs + node --version + npm --version + + - run: + name: Install Node.js dependencies + command: | + npm install @google-cloud/vertexai + npm install --save-dev jest + + - run: + name: Run Vertex AI tests + command: | + npx jest tests/pass_through_tests/test_vertex.test.js --verbose + no_output_timeout: 30m - run: name: Run tests command: | @@ -1179,7 +1208,6 @@ jobs: ls python -m pytest -vv tests/pass_through_tests/ -x --junitxml=test-results/junit.xml --durations=5 no_output_timeout: 120m - # Store test results - store_test_results: path: test-results diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md index 07b0beb75..03190c839 100644 --- a/docs/my-website/docs/pass_through/vertex_ai.md +++ b/docs/my-website/docs/pass_through/vertex_ai.md @@ -12,6 +12,71 @@ Looking for the Unified API (OpenAI format) for VertexAI ? [Go here - using vert ::: +Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation). + +Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex-ai` + + +#### **Example Usage** + + + + +```bash +curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "contents":[{ + "role": "user", + "parts":[{"text": "How are you doing today?"}] + }] + }' +``` + + + + +```javascript +const { VertexAI } = require('@google-cloud/vertexai'); + +const vertexAI = new VertexAI({ + project: 'your-project-id', // enter your vertex project id + location: 'us-central1', // enter your vertex region + apiEndpoint: "localhost:4000/vertex-ai" // /vertex-ai # note, do not include 'https://' in the url +}); + +const model = vertexAI.getGenerativeModel({ + model: 'gemini-1.0-pro' +}, { + customHeaders: { + "x-litellm-api-key": "sk-1234" // Your litellm Virtual Key + } +}); + +async function generateContent() { + try { + const prompt = { + contents: [{ + role: 'user', + parts: [{ text: 'How are you doing today?' }] + }] + }; + + const response = await model.generateContent(prompt); + console.log('Response:', response); + } catch (error) { + console.error('Error:', error); + } +} + +generateContent(); +``` + + + + + ## Supported API Endpoints - Gemini API diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 35cff0db3..ad5a98258 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -100,7 +100,7 @@ class AnthropicPassthroughLoggingHandler: kwargs["response_cost"] = response_cost kwargs["model"] = model - # Make standard logging object for Vertex AI + # Make standard logging object for Anthropic standard_logging_object = get_standard_logging_object_payload( kwargs=kwargs, init_response_obj=litellm_model_response, diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index fe61f32ee..275a0a119 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -56,8 +56,14 @@ class VertexPassthroughLoggingHandler: encoding=None, ) ) - logging_obj.model = litellm_model_response.model or model - logging_obj.model_call_details["model"] = logging_obj.model + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=litellm_model_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + ) await logging_obj.async_success_handler( result=litellm_model_response, @@ -147,6 +153,14 @@ class VertexPassthroughLoggingHandler: "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." ) return + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=complete_streaming_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=litellm_logging_obj, + ) await litellm_logging_obj.async_success_handler( result=complete_streaming_response, start_time=start_time, @@ -193,3 +207,47 @@ class VertexPassthroughLoggingHandler: if match: return match.group(1) return "unknown" + + @staticmethod + def _create_vertex_response_logging_payload_for_generate_content( + litellm_model_response: Union[ + litellm.ModelResponse, litellm.TextCompletionResponse + ], + model: str, + kwargs: dict, + start_time: datetime, + end_time: datetime, + logging_obj: LiteLLMLoggingObj, + ): + """ + Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming) + + """ + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + kwargs["response_cost"] = response_cost + kwargs["model"] = model + + # Make standard logging object for Vertex AI + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=litellm_model_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) + ) + kwargs["standard_logging_object"] = standard_logging_object + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + logging_obj.model = litellm_model_response.model or model + logging_obj.model_call_details["model"] = logging_obj.model + return kwargs diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index baf107a16..f60fd0166 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -36,7 +36,7 @@ from litellm.proxy._types import ( from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.secret_managers.main import get_secret_str -from .streaming_handler import chunk_processor +from .streaming_handler import PassThroughStreamingHandler from .success_handler import PassThroughEndpointLogging from .types import EndpointType, PassthroughStandardLoggingPayload @@ -448,7 +448,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, @@ -491,7 +491,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index 9ba5adfec..522319aaa 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -27,93 +27,107 @@ from .success_handler import PassThroughEndpointLogging from .types import EndpointType -async def chunk_processor( - response: httpx.Response, - request_body: Optional[dict], - litellm_logging_obj: LiteLLMLoggingObj, - endpoint_type: EndpointType, - start_time: datetime, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, -): - """ - - Yields chunks from the response - - Collect non-empty chunks for post-processing (logging) - """ - collected_chunks: List[str] = [] # List to store all chunks - try: - async for chunk in response.aiter_lines(): - verbose_proxy_logger.debug(f"Processing chunk: {chunk}") - if not chunk: - continue +class PassThroughStreamingHandler: - # Handle SSE format - pass through the raw SSE format - if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8") + @staticmethod + async def chunk_processor( + response: httpx.Response, + request_body: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + ): + """ + - Yields chunks from the response + - Collect non-empty chunks for post-processing (logging) + """ + try: + raw_bytes: List[bytes] = [] + async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) + yield chunk - # Store the chunk for post-processing - if chunk.strip(): # Only store non-empty chunks - collected_chunks.append(chunk) - yield f"{chunk}\n" + # After all chunks are processed, handle post-processing + end_time = datetime.now() - # After all chunks are processed, handle post-processing - end_time = datetime.now() + await PassThroughStreamingHandler._route_streaming_logging_to_handler( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body or {}, + endpoint_type=endpoint_type, + start_time=start_time, + raw_bytes=raw_bytes, + end_time=end_time, + ) + except Exception as e: + verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") + raise - await _route_streaming_logging_to_handler( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body or {}, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=collected_chunks, - end_time=end_time, + @staticmethod + async def _route_streaming_logging_to_handler( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + raw_bytes: List[bytes], + end_time: datetime, + ): + """ + Route the logging for the collected chunks to the appropriate handler + + Supported endpoint types: + - Anthropic + - Vertex AI + """ + all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( + raw_bytes ) + if endpoint_type == EndpointType.ANTHROPIC: + await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.VERTEX_AI: + await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.GENERIC: + # No logging is supported for generic streaming endpoints + pass - except Exception as e: - verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") - raise + @staticmethod + def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: + """ + Converts a list of raw bytes into a list of string lines, similar to aiter_lines() + Args: + raw_bytes: List of bytes chunks from aiter.bytes() -async def _route_streaming_logging_to_handler( - litellm_logging_obj: LiteLLMLoggingObj, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, - request_body: dict, - endpoint_type: EndpointType, - start_time: datetime, - all_chunks: List[str], - end_time: datetime, -): - """ - Route the logging for the collected chunks to the appropriate handler + Returns: + List of string lines, with each line being a complete data: {} chunk + """ + # Combine all bytes and decode to string + combined_str = b"".join(raw_bytes).decode("utf-8") - Supported endpoint types: - - Anthropic - - Vertex AI - """ - if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + # Split by newlines and filter out empty lines + lines = [line.strip() for line in combined_str.split("\n") if line.strip()] + + return lines diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 2bd5b790c..fbf37ce8d 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -119,7 +119,6 @@ async def vertex_proxy_route( endpoint: str, request: Request, fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): encoded_endpoint = httpx.URL(endpoint).path @@ -127,6 +126,11 @@ async def vertex_proxy_route( verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} + api_key_to_use = get_litellm_virtual_key(request=request) + user_api_key_dict = await user_api_key_auth( + request=request, + api_key=api_key_to_use, + ) vertex_project = None vertex_location = None @@ -214,3 +218,18 @@ async def vertex_proxy_route( ) return received_value + + +def get_litellm_virtual_key(request: Request) -> str: + """ + Extract and format API key from request headers. + Prioritizes x-litellm-api-key over Authorization header. + + + Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key + + """ + litellm_api_key = request.headers.get("x-litellm-api-key") + if litellm_api_key: + return f"Bearer {litellm_api_key}" + return request.headers.get("Authorization", "") diff --git a/tests/anthropic_passthrough/test_anthropic_passthrough.py b/tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy similarity index 100% rename from tests/anthropic_passthrough/test_anthropic_passthrough.py rename to tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy diff --git a/tests/pass_through_tests/test_gemini.js b/tests/pass_through_tests/test_gemini.js new file mode 100644 index 000000000..2b7d6c5c6 --- /dev/null +++ b/tests/pass_through_tests/test_gemini.js @@ -0,0 +1,23 @@ +// const { GoogleGenerativeAI } = require("@google/generative-ai"); + +// const genAI = new GoogleGenerativeAI("sk-1234"); +// const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); + +// const prompt = "Explain how AI works in 2 pages"; + +// async function run() { +// try { +// const result = await model.generateContentStream(prompt, { baseUrl: "http://localhost:4000/gemini" }); +// const response = await result.response; +// console.log(response.text()); +// for await (const chunk of result.stream) { +// const chunkText = chunk.text(); +// console.log(chunkText); +// process.stdout.write(chunkText); +// } +// } catch (error) { +// console.error("Error:", error); +// } +// } + +// run(); \ No newline at end of file diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js new file mode 100644 index 000000000..7ae9b942a --- /dev/null +++ b/tests/pass_through_tests/test_local_vertex.js @@ -0,0 +1,68 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + +const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" +}); + + +// Use customHeaders in RequestOptions +const requestOptions = { + customHeaders: new Headers({ + "x-litellm-api-key": "sk-1234" + }) +}; + +const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions +); + +async function streamingResponse() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + } catch (error) { + console.error('Error:', error); + } +} + + +async function nonStreamingResponse() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const response = await generativeModel.generateContent(request); + console.log('non streaming response: ', JSON.stringify(response)); + } catch (error) { + console.error('Error:', error); + } +} + + + +streamingResponse(); +nonStreamingResponse(); \ No newline at end of file diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js new file mode 100644 index 000000000..dc457c68a --- /dev/null +++ b/tests/pass_through_tests/test_vertex.test.js @@ -0,0 +1,114 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); +const fs = require('fs'); +const path = require('path'); +const os = require('os'); +const { writeFileSync } = require('fs'); + + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + +function loadVertexAiCredentials() { + console.log("loading vertex ai credentials"); + const filepath = path.dirname(__filename); + const vertexKeyPath = path.join(filepath, "vertex_key.json"); + + // Initialize default empty service account data + let serviceAccountKeyData = {}; + + // Try to read existing vertex_key.json + try { + const content = fs.readFileSync(vertexKeyPath, 'utf8'); + if (content && content.trim()) { + serviceAccountKeyData = JSON.parse(content); + } + } catch (error) { + // File doesn't exist or is invalid, continue with empty object + } + + // Update with environment variables + const privateKeyId = process.env.VERTEX_AI_PRIVATE_KEY_ID || ""; + const privateKey = (process.env.VERTEX_AI_PRIVATE_KEY || "").replace(/\\n/g, "\n"); + + serviceAccountKeyData.private_key_id = privateKeyId; + serviceAccountKeyData.private_key = privateKey; + + // Create temporary file + const tempFilePath = path.join(os.tmpdir(), `vertex-credentials-${Date.now()}.json`); + writeFileSync(tempFilePath, JSON.stringify(serviceAccountKeyData, null, 2)); + + // Set environment variable + process.env.GOOGLE_APPLICATION_CREDENTIALS = tempFilePath; +} + +// Run credential loading before tests +beforeAll(() => { + loadVertexAiCredentials(); +}); + + + +describe('Vertex AI Tests', () => { + test('should successfully generate content from Vertex AI', async () => { + const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" + }); + + const customHeaders = new Headers({ + "x-litellm-api-key": "sk-1234" + }); + + const requestOptions = { + customHeaders: customHeaders + }; + + const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions + ); + + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + + const streamingResult = await generativeModel.generateContentStream(request); + + // Add some assertions + expect(streamingResult).toBeDefined(); + + for await (const item of streamingResult.stream) { + console.log('stream chunk:', JSON.stringify(item)); + expect(item).toBeDefined(); + } + + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response:', JSON.stringify(aggregatedResponse)); + expect(aggregatedResponse).toBeDefined(); + }); + + + test('should successfully generate non-streaming content from Vertex AI', async () => { + const vertexAI = new VertexAI({project: 'adroit-crow-413218', location: 'us-central1', apiEndpoint: "localhost:4000/vertex-ai"}); + const customHeaders = new Headers({"x-litellm-api-key": "sk-1234"}); + const requestOptions = {customHeaders: customHeaders}; + const generativeModel = vertexAI.getGenerativeModel({model: 'gemini-1.0-pro'}, requestOptions); + const request = {contents: [{role: 'user', parts: [{text: 'What is 2+2?'}]}]}; + + const result = await generativeModel.generateContent(request); + expect(result).toBeDefined(); + expect(result.response).toBeDefined(); + console.log('non-streaming response:', JSON.stringify(result.response)); + }); +}); \ No newline at end of file diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic.py b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py similarity index 100% rename from tests/pass_through_unit_tests/test_unit_test_anthropic.py rename to tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py new file mode 100644 index 000000000..bbbc465fc --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -0,0 +1,118 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch, MagicMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import httpx +import pytest +import litellm +from typing import AsyncGenerator +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy.pass_through_endpoints.types import EndpointType +from litellm.proxy.pass_through_endpoints.success_handler import ( + PassThroughEndpointLogging, +) +from litellm.proxy.pass_through_endpoints.streaming_handler import ( + PassThroughStreamingHandler, +) + + +# Helper function to mock async iteration +async def aiter_mock(iterable): + for item in iterable: + yield item + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "endpoint_type,url_route", + [ + ( + EndpointType.VERTEX_AI, + "v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent", + ), + (EndpointType.ANTHROPIC, "/v1/messages"), + ], +) +async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): + """ + Test that the chunk_processor yields raw bytes + + This is CRITICAL for pass throughs streaming with Vertex AI and Anthropic + """ + # Mock inputs + response = AsyncMock(spec=httpx.Response) + raw_chunks = [ + b'{"id": "1", "content": "Hello"}', + b'{"id": "2", "content": "World"}', + b'\n\ndata: {"id": "3"}', # Testing different byte formats + ] + + # Mock aiter_bytes to return an async generator + async def mock_aiter_bytes(): + for chunk in raw_chunks: + yield chunk + + response.aiter_bytes = mock_aiter_bytes + + request_body = {"key": "value"} + litellm_logging_obj = MagicMock() + start_time = datetime.now() + passthrough_success_handler_obj = MagicMock() + + # Capture yielded chunks and perform detailed assertions + received_chunks = [] + async for chunk in PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=request_body, + litellm_logging_obj=litellm_logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + ): + # Assert each chunk is bytes + assert isinstance(chunk, bytes), f"Chunk should be bytes, got {type(chunk)}" + # Assert no decoding/encoding occurred (chunk should be exactly as input) + assert ( + chunk in raw_chunks + ), f"Chunk {chunk} was modified during processing. For pass throughs streaming, chunks should be raw bytes" + received_chunks.append(chunk) + + # Assert all chunks were processed + assert len(received_chunks) == len(raw_chunks), "Not all chunks were processed" + + # collected chunks all together + assert b"".join(received_chunks) == b"".join( + raw_chunks + ), "Collected chunks do not match raw chunks" + + +def test_convert_raw_bytes_to_str_lines(): + """ + Test that the _convert_raw_bytes_to_str_lines method correctly converts raw bytes to a list of strings + """ + # Test case 1: Single chunk + raw_bytes = [b'data: {"content": "Hello"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}'] + + # Test case 2: Multiple chunks + raw_bytes = [b'data: {"content": "Hello"}\n', b'data: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] + + # Test case 3: Empty input + raw_bytes = [] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == [] + + # Test case 4: Chunks with empty lines + raw_bytes = [b'data: {"content": "Hello"}\n\n', b'\ndata: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py new file mode 100644 index 000000000..a7b668813 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -0,0 +1,84 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + +from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + get_litellm_virtual_key, + vertex_proxy_route, +) + + +@pytest.mark.asyncio +async def test_get_litellm_virtual_key(): + """ + Test that the get_litellm_virtual_key function correctly handles the API key authentication + """ + # Test with x-litellm-api-key + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + # Test with Authorization header + mock_request.headers = {"Authorization": "Bearer auth-key-456"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer auth-key-456" + + # Test with both headers (x-litellm-api-key should take precedence) + mock_request.headers = { + "x-litellm-api-key": "test-key-123", + "Authorization": "Bearer auth-key-456", + } + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + +@pytest.mark.asyncio +async def test_vertex_proxy_route_api_key_auth(): + """ + Critical + + This is how Vertex AI JS SDK will Auth to Litellm Proxy + """ + # Mock dependencies + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + mock_request.method = "POST" + mock_response = Mock() + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth" + ) as mock_auth: + mock_auth.return_value = {"api_key": "test-key-123"} + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route" + ) as mock_pass_through: + mock_pass_through.return_value = AsyncMock( + return_value={"status": "success"} + ) + + # Call the function + result = await vertex_proxy_route( + endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent", + request=mock_request, + fastapi_response=mock_response, + ) + + # Verify user_api_key_auth was called with the correct Bearer token + mock_auth.assert_called_once() + call_args = mock_auth.call_args[1] + assert call_args["api_key"] == "Bearer test-key-123"