From f83708ed4eb9d3e6939e6fcbdac0df9f3750b31b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 12:59:01 -0800 Subject: [PATCH 01/21] stash gemini JS test --- tests/pass_through_tests/test_gemini.js | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 tests/pass_through_tests/test_gemini.js diff --git a/tests/pass_through_tests/test_gemini.js b/tests/pass_through_tests/test_gemini.js new file mode 100644 index 000000000..2b7d6c5c6 --- /dev/null +++ b/tests/pass_through_tests/test_gemini.js @@ -0,0 +1,23 @@ +// const { GoogleGenerativeAI } = require("@google/generative-ai"); + +// const genAI = new GoogleGenerativeAI("sk-1234"); +// const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); + +// const prompt = "Explain how AI works in 2 pages"; + +// async function run() { +// try { +// const result = await model.generateContentStream(prompt, { baseUrl: "http://localhost:4000/gemini" }); +// const response = await result.response; +// console.log(response.text()); +// for await (const chunk of result.stream) { +// const chunkText = chunk.text(); +// console.log(chunkText); +// process.stdout.write(chunkText); +// } +// } catch (error) { +// console.error("Error:", error); +// } +// } + +// run(); \ No newline at end of file From dcab2d0c6f57e566fbee4027a83c94d8af2dfd09 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 12:59:34 -0800 Subject: [PATCH 02/21] add vertex js sdk example --- tests/pass_through_tests/test_vertex.js | 43 +++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/pass_through_tests/test_vertex.js diff --git a/tests/pass_through_tests/test_vertex.js b/tests/pass_through_tests/test_vertex.js new file mode 100644 index 000000000..39db0ad0d --- /dev/null +++ b/tests/pass_through_tests/test_vertex.js @@ -0,0 +1,43 @@ + +const { + FunctionDeclarationSchemaType, + HarmBlockThreshold, + HarmCategory, + VertexAI, + RequestOptions + } = require('@google-cloud/vertexai'); + + const project = 'adroit-crow-413218'; + const location = 'us-central1'; + const textModel = 'gemini-1.0-pro'; + const visionModel = 'gemini-1.0-pro-vision'; + + + const vertexAI = new VertexAI({project: project, location: location, apiEndpoint: "localhost:4000/vertex-ai"}); + + // Instantiate Gemini models + const generativeModel = vertexAI.getGenerativeModel({ + model: textModel, + // The following parameters are optional + // They can also be passed to individual content generation requests + safetySettings: [{category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], + generationConfig: {maxOutputTokens: 256}, + systemInstruction: { + role: 'system', + parts: [{"text": `For example, you are a helpful customer service agent. tell me your name.
in 5 pages`}] }, }) + +async function streamGenerateContent() { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + }; + + streamGenerateContent(); \ No newline at end of file From e829b228b24c266481b03b96b063f43e96b1d132 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 13:18:08 -0800 Subject: [PATCH 03/21] handle vertex pass through separately --- .../streaming_handler.py | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index 9ba5adfec..a88ad34d3 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -40,35 +40,39 @@ async def chunk_processor( - Yields chunks from the response - Collect non-empty chunks for post-processing (logging) """ - collected_chunks: List[str] = [] # List to store all chunks try: - async for chunk in response.aiter_lines(): - verbose_proxy_logger.debug(f"Processing chunk: {chunk}") - if not chunk: - continue + if endpoint_type == EndpointType.VERTEX_AI: + async for chunk in response.aiter_bytes(): + yield chunk + else: + collected_chunks: List[str] = [] # List to store all chunks + async for chunk in response.aiter_lines(): + verbose_proxy_logger.debug(f"Processing chunk: {chunk}") + if not chunk: + continue + # Handle SSE format - pass through the raw SSE format + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8") - # Handle SSE format - pass through the raw SSE format - if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8") + # Store the chunk for post-processing + if chunk.strip(): # Only store non-empty chunks + collected_chunks.append(chunk) + yield f"{chunk}\n" - # Store the chunk for post-processing - if chunk.strip(): # Only store non-empty chunks - collected_chunks.append(chunk) - yield f"{chunk}\n" + # After all chunks are processed, handle post-processing + end_time = datetime.now() - # After all chunks are processed, handle post-processing - end_time = datetime.now() - - await _route_streaming_logging_to_handler( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body or {}, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=collected_chunks, - end_time=end_time, - ) + await _route_streaming_logging_to_handler( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body or {}, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=collected_chunks, + end_time=end_time, + ) except Exception as e: verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") From bbb2e029b562d7924c7186146c59d121b72436b0 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 13:18:23 -0800 Subject: [PATCH 04/21] test vertex JS sdk --- tests/pass_through_tests/test_vertex.js | 75 ++++++++++++------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git
a/tests/pass_through_tests/test_vertex.js b/tests/pass_through_tests/test_vertex.js index 39db0ad0d..1f2eaea33 100644 --- a/tests/pass_through_tests/test_vertex.js +++ b/tests/pass_through_tests/test_vertex.js @@ -1,43 +1,40 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); -const { - FunctionDeclarationSchemaType, - HarmBlockThreshold, - HarmCategory, - VertexAI, - RequestOptions - } = require('@google-cloud/vertexai'); - - const project = 'adroit-crow-413218'; - const location = 'us-central1'; - const textModel = 'gemini-1.0-pro'; - const visionModel = 'gemini-1.0-pro-vision'; +const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" +}); - - const vertexAI = new VertexAI({project: project, location: location, apiEndpoint: "localhost:4000/vertex-ai"}); - - // Instantiate Gemini models - const generativeModel = vertexAI.getGenerativeModel({ - model: textModel, - // The following parameters are optional - // They can also be passed to individual content generation requests - safetySettings: [{category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], - generationConfig: {maxOutputTokens: 256}, - systemInstruction: { - role: 'system', - parts: [{"text": `For example, you are a helpful customer service agent. tell me your name. in 5 pages`}] - }, - }) +// Create customHeaders using Headers +const customHeaders = new Headers({ + "X-Litellm-Api-Key": "sk-1234" +}); -async function streamGenerateContent() { - const request = { - contents: [{role: 'user', parts: [{text: 'How are you doing today?'}]}], - }; - const streamingResult = await generativeModel.generateContentStream(request); - for await (const item of streamingResult.stream) { - console.log('stream chunk: ', JSON.stringify(item)); +// Use customHeaders in RequestOptions +const requestOptions = { + customHeaders: customHeaders +}; + +const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions +); + +async function testModel() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + } catch (error) { + console.error('Error:', error); } - const aggregatedResponse = await streamingResult.response; - console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); - }; - - streamGenerateContent(); \ No newline at end of file +} + +testModel(); \ No newline at end of file From 4273837addb50fa16621529ccdca65f2f9c93503 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 13:19:01 -0800 Subject: [PATCH 05/21] fix vertex_proxy_route --- .../proxy/vertex_ai_endpoints/vertex_endpoints.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 2bd5b790c..73a38bdf7 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -119,7 +119,6 @@ async def vertex_proxy_route( endpoint: str, request: Request, fastapi_response: Response, - 
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): encoded_endpoint = httpx.URL(endpoint).path @@ -128,6 +127,20 @@ async def vertex_proxy_route( verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} + # TODO - clean this up before merging + litellm_api_key = request.headers.get("X-Litellm-Api-Key") + api_key_to_use = "" + if litellm_api_key: + api_key_to_use = f"Bearer {litellm_api_key}" + else: + api_key_to_use = request.headers.get("Authorization") + + api_key_to_use = api_key_to_use or "" + user_api_key_dict = await user_api_key_auth( + request=request, + api_key=api_key_to_use, + ) + vertex_project = None vertex_location = None # Use headers from the incoming request if default_vertex_config is not set From 04c9284da43982e692975f47e9b1ad3c126ad464 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 14:19:28 -0800 Subject: [PATCH 06/21] use PassThroughStreamingHandler --- .../streaming_handler.py | 164 ++++++++++-------- 1 file changed, 87 insertions(+), 77 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index a88ad34d3..522319aaa 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -27,97 +27,107 @@ from .success_handler import PassThroughEndpointLogging from .types import EndpointType -async def chunk_processor( - response: httpx.Response, - request_body: Optional[dict], - litellm_logging_obj: LiteLLMLoggingObj, - endpoint_type: EndpointType, - start_time: datetime, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, -): - """ - - Yields chunks from the response - - Collect non-empty chunks for post-processing (logging) - """ - try: - if endpoint_type == EndpointType.VERTEX_AI: +class PassThroughStreamingHandler: + + @staticmethod + async def chunk_processor( + response: httpx.Response, + request_body: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + ): + """ + - Yields chunks from the response + - Collect non-empty chunks for post-processing (logging) + """ + try: + raw_bytes: List[bytes] = [] async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) yield chunk - else: - collected_chunks: List[str] = [] # List to store all chunks - async for chunk in response.aiter_lines(): - verbose_proxy_logger.debug(f"Processing chunk: {chunk}") - if not chunk: - continue - - # Handle SSE format - pass through the raw SSE format - if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8") - - # Store the chunk for post-processing - if chunk.strip(): # Only store non-empty chunks - collected_chunks.append(chunk) - yield f"{chunk}\n" # After all chunks are processed, handle post-processing end_time = datetime.now() - await _route_streaming_logging_to_handler( + await PassThroughStreamingHandler._route_streaming_logging_to_handler( litellm_logging_obj=litellm_logging_obj, passthrough_success_handler_obj=passthrough_success_handler_obj, url_route=url_route, request_body=request_body or {}, endpoint_type=endpoint_type, start_time=start_time, - all_chunks=collected_chunks, + raw_bytes=raw_bytes, end_time=end_time, ) + except Exception as e: + verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") + raise - except Exception as e: - verbose_proxy_logger.error(f"Error in 
chunk_processor: {str(e)}") - raise + @staticmethod + async def _route_streaming_logging_to_handler( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + raw_bytes: List[bytes], + end_time: datetime, + ): + """ + Route the logging for the collected chunks to the appropriate handler - -async def _route_streaming_logging_to_handler( - litellm_logging_obj: LiteLLMLoggingObj, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, - request_body: dict, - endpoint_type: EndpointType, - start_time: datetime, - all_chunks: List[str], - end_time: datetime, -): - """ - Route the logging for the collected chunks to the appropriate handler - - Supported endpoint types: - - Anthropic - - Vertex AI - """ - if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, + Supported endpoint types: + - Anthropic + - Vertex AI + """ + all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( + raw_bytes ) - elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + if endpoint_type == EndpointType.ANTHROPIC: + await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.VERTEX_AI: + await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.GENERIC: + # No logging is supported for generic streaming endpoints + pass + + @staticmethod + def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: + """ + Converts a list of raw bytes into a list of string lines, similar to aiter_lines() + + Args: + raw_bytes: List of bytes chunks from aiter.bytes() + + Returns: + List of string lines, with each line being a complete data: {} chunk + """ + # Combine all bytes and decode to string + combined_str = b"".join(raw_bytes).decode("utf-8") + + # Split by newlines and filter out empty lines + lines = [line.strip() for line in combined_str.split("\n") if line.strip()] + + return lines From 7422af75fd799f7063f05e3a754bcf8596795294 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 14:20:21 -0800 
Subject: [PATCH 07/21] fix PassThroughStreamingHandler --- .../proxy/pass_through_endpoints/pass_through_endpoints.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index baf107a16..f60fd0166 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -36,7 +36,7 @@ from litellm.proxy._types import ( from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.secret_managers.main import get_secret_str -from .streaming_handler import chunk_processor +from .streaming_handler import PassThroughStreamingHandler from .success_handler import PassThroughEndpointLogging from .types import EndpointType, PassthroughStandardLoggingPayload @@ -448,7 +448,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, @@ -491,7 +491,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, From 49724153725a981597c25145ca9b9d9fe32b97f4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 14:35:02 -0800 Subject: [PATCH 08/21] use common _create_vertex_response_logging_payload_for_generate_content --- .../anthropic_passthrough_logging_handler.py | 2 +- .../vertex_passthrough_logging_handler.py | 62 ++++++++++++++++++- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 35cff0db3..ad5a98258 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -100,7 +100,7 @@ class AnthropicPassthroughLoggingHandler: kwargs["response_cost"] = response_cost kwargs["model"] = model - # Make standard logging object for Vertex AI + # Make standard logging object for Anthropic standard_logging_object = get_standard_logging_object_payload( kwargs=kwargs, init_response_obj=litellm_model_response, diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index fe61f32ee..275a0a119 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -56,8 +56,14 @@ class VertexPassthroughLoggingHandler: encoding=None, ) ) - logging_obj.model = litellm_model_response.model or model - logging_obj.model_call_details["model"] = logging_obj.model + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=litellm_model_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + ) await logging_obj.async_success_handler( 
result=litellm_model_response, @@ -147,6 +153,14 @@ class VertexPassthroughLoggingHandler: "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." ) return + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=complete_streaming_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=litellm_logging_obj, + ) await litellm_logging_obj.async_success_handler( result=complete_streaming_response, start_time=start_time, @@ -193,3 +207,47 @@ class VertexPassthroughLoggingHandler: if match: return match.group(1) return "unknown" + + @staticmethod + def _create_vertex_response_logging_payload_for_generate_content( + litellm_model_response: Union[ + litellm.ModelResponse, litellm.TextCompletionResponse + ], + model: str, + kwargs: dict, + start_time: datetime, + end_time: datetime, + logging_obj: LiteLLMLoggingObj, + ): + """ + Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming) + + """ + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + kwargs["response_cost"] = response_cost + kwargs["model"] = model + + # Make standard logging object for Vertex AI + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=litellm_model_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) + ) + kwargs["standard_logging_object"] = standard_logging_object + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + logging_obj.model = litellm_model_response.model or model + logging_obj.model_call_details["model"] = logging_obj.model + return kwargs From 53e82b7f14afee6ffb15327903b7d7cd7587972e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:17:59 -0800 Subject: [PATCH 09/21] test vertex js --- .circleci/config.yml | 41 ++++++++++++++++++- tests/pass_through_tests/test_local_vertex.js | 40 ++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 tests/pass_through_tests/test_local_vertex.js diff --git a/.circleci/config.yml b/.circleci/config.yml index db7c4ef5b..950a6cc0c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1094,6 +1094,26 @@ jobs: working_directory: ~/project steps: - checkout + # New steps to run Node.js test + - run: + name: Install Node.js + command: | + curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - + sudo apt-get install -y nodejs + node --version + npm --version + + - run: + name: Install Node.js dependencies + command: | + npm install @google-cloud/vertexai + npm install --save-dev jest + + - run: + name: Run Vertex AI tests + command: | + npx jest tests/pass_through_tests/test_vertex.test.js --verbose + no_output_timeout: 30m - run: name: Install Docker CLI (In case it's not already installed) command: | @@ -1172,6 +1192,26 @@ jobs: - run: name: Wait for app to be ready command: dockerize -wait http://localhost:4000 -timeout 5m + # New steps to run Node.js test + - run: + name: Install Node.js + command: | + curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - + sudo apt-get install -y nodejs + node --version + npm --version + + - run: + 
name: Install Node.js dependencies + command: | + npm install @google-cloud/vertexai + npm install --save-dev jest + + - run: + name: Run Vertex AI tests + command: | + npx jest tests/pass_through_tests/test_vertex.test.js --verbose + no_output_timeout: 30m - run: name: Run tests command: | @@ -1179,7 +1219,6 @@ jobs: ls python -m pytest -vv tests/pass_through_tests/ -x --junitxml=test-results/junit.xml --durations=5 no_output_timeout: 120m - # Store test results - store_test_results: path: test-results diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js new file mode 100644 index 000000000..1f2eaea33 --- /dev/null +++ b/tests/pass_through_tests/test_local_vertex.js @@ -0,0 +1,40 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + +const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" +}); + +// Create customHeaders using Headers +const customHeaders = new Headers({ + "X-Litellm-Api-Key": "sk-1234" +}); + +// Use customHeaders in RequestOptions +const requestOptions = { + customHeaders: customHeaders +}; + +const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions +); + +async function testModel() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + } catch (error) { + console.error('Error:', error); + } +} + +testModel(); \ No newline at end of file From d3f23e0528dd6dec8a5d170cb711f3f909f4f4ce Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:40:56 -0800 Subject: [PATCH 10/21] add working vertex jest tests --- tests/pass_through_tests/test_local_vertex.js | 14 ++++ tests/pass_through_tests/test_vertex.js | 40 ----------- tests/pass_through_tests/test_vertex.test.js | 71 +++++++++++++++++++ 3 files changed, 85 insertions(+), 40 deletions(-) delete mode 100644 tests/pass_through_tests/test_vertex.js create mode 100644 tests/pass_through_tests/test_vertex.test.js diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index 1f2eaea33..6e8fbeacd 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -1,5 +1,19 @@ const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + const vertexAI = new VertexAI({ project: 'adroit-crow-413218', location: 'us-central1', diff --git a/tests/pass_through_tests/test_vertex.js b/tests/pass_through_tests/test_vertex.js deleted file mode 100644 index 1f2eaea33..000000000 --- a/tests/pass_through_tests/test_vertex.js +++ /dev/null @@ -1,40 +0,0 @@ -const { VertexAI, RequestOptions 
} = require('@google-cloud/vertexai'); - -const vertexAI = new VertexAI({ - project: 'adroit-crow-413218', - location: 'us-central1', - apiEndpoint: "localhost:4000/vertex-ai" -}); - -// Create customHeaders using Headers -const customHeaders = new Headers({ - "X-Litellm-Api-Key": "sk-1234" -}); - -// Use customHeaders in RequestOptions -const requestOptions = { - customHeaders: customHeaders -}; - -const generativeModel = vertexAI.getGenerativeModel( - { model: 'gemini-1.0-pro' }, - requestOptions -); - -async function testModel() { - try { - const request = { - contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], - }; - const streamingResult = await generativeModel.generateContentStream(request); - for await (const item of streamingResult.stream) { - console.log('stream chunk: ', JSON.stringify(item)); - } - const aggregatedResponse = await streamingResult.response; - console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); - } catch (error) { - console.error('Error:', error); - } -} - -testModel(); \ No newline at end of file diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js new file mode 100644 index 000000000..d651650c2 --- /dev/null +++ b/tests/pass_through_tests/test_vertex.test.js @@ -0,0 +1,71 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + + +describe('Vertex AI Tests', () => { + test('should successfully generate content from Vertex AI', async () => { + const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" + }); + + const customHeaders = new Headers({ + "X-Litellm-Api-Key": "sk-1234" + }); + + const requestOptions = { + customHeaders: customHeaders + }; + + const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions + ); + + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + + const streamingResult = await generativeModel.generateContentStream(request); + + // Add some assertions + expect(streamingResult).toBeDefined(); + + for await (const item of streamingResult.stream) { + console.log('stream chunk:', JSON.stringify(item)); + expect(item).toBeDefined(); + } + + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response:', JSON.stringify(aggregatedResponse)); + expect(aggregatedResponse).toBeDefined(); + }); + + + test('should successfully generate non-streaming content from Vertex AI', async () => { + const vertexAI = new VertexAI({project: 'adroit-crow-413218', location: 'us-central1', apiEndpoint: "localhost:4000/vertex-ai"}); + const customHeaders = new Headers({"X-Litellm-Api-Key": "sk-1234"}); + const requestOptions = {customHeaders: customHeaders}; + const generativeModel = vertexAI.getGenerativeModel({model: 'gemini-1.0-pro'}, requestOptions); + const request = {contents: [{role: 'user', parts: [{text: 'What is 2+2?'}]}]}; + + const result = await 
generativeModel.generateContent(request); + expect(result).toBeDefined(); + expect(result.response).toBeDefined(); + console.log('non-streaming response:', JSON.stringify(result.response)); + }); +}); \ No newline at end of file From c7c586c8a6cc45f9e93ecb13a76e0ba0f238ddab Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:44:28 -0800 Subject: [PATCH 11/21] move basic pass through test --- .../test_anthropic_passthrough copy.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{anthropic_passthrough/test_anthropic_passthrough.py => pass_through_tests/test_anthropic_passthrough copy.py} (100%) diff --git a/tests/anthropic_passthrough/test_anthropic_passthrough.py b/tests/pass_through_tests/test_anthropic_passthrough copy.py similarity index 100% rename from tests/anthropic_passthrough/test_anthropic_passthrough.py rename to tests/pass_through_tests/test_anthropic_passthrough copy.py From 4b607e0cc280d073af68aefbd4092ed795045bab Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:44:57 -0800 Subject: [PATCH 12/21] use good name for test --- ...assthrough copy.py => test_anthropic_passthrough_python_sdkpy} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/pass_through_tests/{test_anthropic_passthrough copy.py => test_anthropic_passthrough_python_sdkpy} (100%) diff --git a/tests/pass_through_tests/test_anthropic_passthrough copy.py b/tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy similarity index 100% rename from tests/pass_through_tests/test_anthropic_passthrough copy.py rename to tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy From 4b576571a1c6c8a604fb1d0cdf3dc05ac5f7440f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:47:03 -0800 Subject: [PATCH 13/21] test vertex --- tests/pass_through_tests/test_vertex.test.js | 43 ++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js index d651650c2..766050da7 100644 --- a/tests/pass_through_tests/test_vertex.test.js +++ b/tests/pass_through_tests/test_vertex.test.js @@ -1,4 +1,8 @@ const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); +const fs = require('fs'); +const path = require('path'); +const os = require('os'); +const { writeFileSync } = require('fs'); // Import fetch if the SDK uses it @@ -14,6 +18,45 @@ global.fetch = async function patchedFetch(url, options) { return originalFetch(url, options); }; +function loadVertexAiCredentials() { + console.log("loading vertex ai credentials"); + const filepath = path.dirname(__filename); + const vertexKeyPath = path.join(filepath, "vertex_key.json"); + + // Initialize default empty service account data + let serviceAccountKeyData = {}; + + // Try to read existing vertex_key.json + try { + const content = fs.readFileSync(vertexKeyPath, 'utf8'); + if (content && content.trim()) { + serviceAccountKeyData = JSON.parse(content); + } + } catch (error) { + // File doesn't exist or is invalid, continue with empty object + } + + // Update with environment variables + const privateKeyId = process.env.VERTEX_AI_PRIVATE_KEY_ID || ""; + const privateKey = (process.env.VERTEX_AI_PRIVATE_KEY || "").replace(/\\n/g, "\n"); + + serviceAccountKeyData.private_key_id = privateKeyId; + serviceAccountKeyData.private_key = privateKey; + + // Create temporary file + const tempFilePath = path.join(os.tmpdir(), `vertex-credentials-${Date.now()}.json`); + writeFileSync(tempFilePath,
JSON.stringify(serviceAccountKeyData, null, 2)); + + // Set environment variable + process.env.GOOGLE_APPLICATION_CREDENTIALS = tempFilePath; +} + +// Run credential loading before tests +beforeAll(() => { + loadVertexAiCredentials(); +}); + + describe('Vertex AI Tests', () => { test('should successfully generate content from Vertex AI', async () => { From 88dbb706c1e93902e44fed87f68e8475a4f58412 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:04:30 -0800 Subject: [PATCH 14/21] test_chunk_processor_yields_raw_bytes --- .../test_unit_test_anthropic_pass_through.py | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py new file mode 100644 index 000000000..afb77f718 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py @@ -0,0 +1,135 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + +# Import the class we're testing +from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import ( + AnthropicPassthroughLoggingHandler, +) + + +@pytest.fixture +def mock_response(): + return { + "model": "claude-3-opus-20240229", + "content": [{"text": "Hello, world!", "type": "text"}], + "role": "assistant", + } + + +@pytest.fixture +def mock_httpx_response(): + mock_resp = Mock(spec=httpx.Response) + mock_resp.json.return_value = { + "content": [{"text": "Hi! 
My name is Claude.", "type": "text"}], + "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", + "model": "claude-3-5-sonnet-20241022", + "role": "assistant", + "stop_reason": "end_turn", + "stop_sequence": None, + "type": "message", + "usage": {"input_tokens": 2095, "output_tokens": 503}, + } + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "application/json"} + return mock_resp + + +@pytest.fixture +def mock_logging_obj(): + logging_obj = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + ) + + logging_obj.async_success_handler = AsyncMock() + return logging_obj + + +@pytest.mark.asyncio +async def test_anthropic_passthrough_handler( + mock_httpx_response, mock_response, mock_logging_obj +): + """ + Unit test - Assert that the anthropic passthrough handler calls the litellm logging object's async_success_handler + """ + start_time = datetime.now() + end_time = datetime.now() + + await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + httpx_response=mock_httpx_response, + response_body=mock_response, + logging_obj=mock_logging_obj, + url_route="/v1/chat/completions", + result="success", + start_time=start_time, + end_time=end_time, + cache_hit=False, + ) + + # Assert that async_success_handler was called + assert mock_logging_obj.async_success_handler.called + + call_args = mock_logging_obj.async_success_handler.call_args + call_kwargs = call_args.kwargs + print("call_kwargs", call_kwargs) + + # Assert required fields are present in call_kwargs + assert "result" in call_kwargs + assert "start_time" in call_kwargs + assert "end_time" in call_kwargs + assert "cache_hit" in call_kwargs + assert "response_cost" in call_kwargs + assert "model" in call_kwargs + assert "standard_logging_object" in call_kwargs + + # Assert specific values and types + assert isinstance(call_kwargs["result"], litellm.ModelResponse) + assert isinstance(call_kwargs["start_time"], datetime) + assert isinstance(call_kwargs["end_time"], datetime) + assert isinstance(call_kwargs["cache_hit"], bool) + assert isinstance(call_kwargs["response_cost"], float) + assert call_kwargs["model"] == "claude-3-opus-20240229" + assert isinstance(call_kwargs["standard_logging_object"], dict) + + +def test_create_anthropic_response_logging_payload(mock_logging_obj): + # Test the logging payload creation + model_response = litellm.ModelResponse() + model_response.choices = [{"message": {"content": "Test response"}}] + + start_time = datetime.now() + end_time = datetime.now() + + result = ( + AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=model_response, + model="claude-3-opus-20240229", + kwargs={}, + start_time=start_time, + end_time=end_time, + logging_obj=mock_logging_obj, + ) + ) + + assert isinstance(result, dict) + assert "model" in result + assert "response_cost" in result + assert "standard_logging_object" in result From 413092ec1c104f2b9e6a4b4449a7ce9e34fa5545 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:04:45 -0800 Subject: [PATCH 15/21] unit tests for streaming --- .../test_unit_test_streaming.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/pass_through_unit_tests/test_unit_test_streaming.py diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py new file mode 100644 index 
000000000..845052382 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -0,0 +1,93 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch, MagicMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import httpx +import pytest +import litellm +from typing import AsyncGenerator +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy.pass_through_endpoints.types import EndpointType +from litellm.proxy.pass_through_endpoints.success_handler import ( + PassThroughEndpointLogging, +) +from litellm.proxy.pass_through_endpoints.streaming_handler import ( + PassThroughStreamingHandler, +) + + +# Helper function to mock async iteration +async def aiter_mock(iterable): + for item in iterable: + yield item + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "endpoint_type,url_route", + [ + ( + EndpointType.VERTEX_AI, + "v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent", + ), + (EndpointType.ANTHROPIC, "/v1/messages"), + ], +) +async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): + """ + Test that the chunk_processor yields raw bytes + + This is CRITICAL for pass throughs streaming with Vertex AI and Anthropic + """ + # Mock inputs + response = AsyncMock(spec=httpx.Response) + raw_chunks = [ + b'{"id": "1", "content": "Hello"}', + b'{"id": "2", "content": "World"}', + b'\n\ndata: {"id": "3"}', # Testing different byte formats + ] + + # Mock aiter_bytes to return an async generator + async def mock_aiter_bytes(): + for chunk in raw_chunks: + yield chunk + + response.aiter_bytes = mock_aiter_bytes + + request_body = {"key": "value"} + litellm_logging_obj = MagicMock() + start_time = datetime.now() + passthrough_success_handler_obj = MagicMock() + + # Capture yielded chunks and perform detailed assertions + received_chunks = [] + async for chunk in PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=request_body, + litellm_logging_obj=litellm_logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + ): + # Assert each chunk is bytes + assert isinstance(chunk, bytes), f"Chunk should be bytes, got {type(chunk)}" + # Assert no decoding/encoding occurred (chunk should be exactly as input) + assert ( + chunk in raw_chunks + ), f"Chunk {chunk} was modified during processing. 
For pass throughs streaming, chunks should be raw bytes" + received_chunks.append(chunk) + + # Assert all chunks were processed + assert len(received_chunks) == len(raw_chunks), "Not all chunks were processed" + + # collected chunks all together + assert b"".join(received_chunks) == b"".join( + raw_chunks + ), "Collected chunks do not match raw chunks" From 06da8a5fbc666c9dcc7bb4798a8633a258d1ac0b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:07:45 -0800 Subject: [PATCH 16/21] test_convert_raw_bytes_to_str_lines --- .../test_unit_test_streaming.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py index 845052382..bbbc465fc 100644 --- a/tests/pass_through_unit_tests/test_unit_test_streaming.py +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -91,3 +91,28 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): assert b"".join(received_chunks) == b"".join( raw_chunks ), "Collected chunks do not match raw chunks" + + +def test_convert_raw_bytes_to_str_lines(): + """ + Test that the _convert_raw_bytes_to_str_lines method correctly converts raw bytes to a list of strings + """ + # Test case 1: Single chunk + raw_bytes = [b'data: {"content": "Hello"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}'] + + # Test case 2: Multiple chunks + raw_bytes = [b'data: {"content": "Hello"}\n', b'data: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] + + # Test case 3: Empty input + raw_bytes = [] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == [] + + # Test case 4: Chunks with empty lines + raw_bytes = [b'data: {"content": "Hello"}\n\n', b'\ndata: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] From 35040f12becb9fd50b5064a3e4321ae2650b03ec Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:15:37 -0800 Subject: [PATCH 17/21] run unit tests 1st --- .circleci/config.yml | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 950a6cc0c..2f9df4d52 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1094,26 +1094,6 @@ jobs: working_directory: ~/project steps: - checkout - # New steps to run Node.js test - - run: - name: Install Node.js - command: | - curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - - sudo apt-get install -y nodejs - node --version - npm --version - - - run: - name: Install Node.js dependencies - command: | - npm install @google-cloud/vertexai - npm install --save-dev jest - - - run: - name: Run Vertex AI tests - command: | - npx jest tests/pass_through_tests/test_vertex.test.js --verbose - no_output_timeout: 30m - run: name: Install Docker CLI (In case it's not already installed) command: | @@ -1157,6 +1137,15 @@ jobs: pip install "PyGithub==1.59.1" pip install "google-cloud-aiplatform==1.59.0" pip install anthropic + python -m pip install -r requirements.txt + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls 
+ python -m pytest -vv tests/pass_through_unit_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m - run: name: Build Docker image command: docker build -t my-app:latest -f ./docker/Dockerfile.database . @@ -1192,7 +1181,7 @@ jobs: - run: name: Wait for app to be ready command: dockerize -wait http://localhost:4000 -timeout 5m - # New steps to run Node.js test + # New steps to run Node.js test - run: name: Install Node.js command: | From 77fe5af5b359a84223b535d9a0832f0d13952c50 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:31:58 -0800 Subject: [PATCH 18/21] simplify local --- tests/pass_through_tests/test_local_vertex.js | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index 6e8fbeacd..f86335636 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -20,14 +20,12 @@ const vertexAI = new VertexAI({ apiEndpoint: "localhost:4000/vertex-ai" }); -// Create customHeaders using Headers -const customHeaders = new Headers({ - "X-Litellm-Api-Key": "sk-1234" -}); // Use customHeaders in RequestOptions const requestOptions = { - customHeaders: customHeaders + customHeaders: new Headers({ + "X-Litellm-Api-Key": "sk-1234" + }) }; const generativeModel = vertexAI.getGenerativeModel( From 7674217e6c2c0127ce971ebabadd356cfa0c0d63 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:40:40 -0800 Subject: [PATCH 19/21] docs add usage example for js --- .../my-website/docs/pass_through/vertex_ai.md | 65 +++++++++++++++++++ tests/pass_through_tests/test_local_vertex.js | 20 +++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md index 07b0beb75..e5491159f 100644 --- a/docs/my-website/docs/pass_through/vertex_ai.md +++ b/docs/my-website/docs/pass_through/vertex_ai.md @@ -12,6 +12,71 @@ Looking for the Unified API (OpenAI format) for VertexAI ? [Go here - using vert ::: +Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation). + +Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex-ai` + + +#### **Example Usage** + + + + +```bash +curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "contents":[{ + "role": "user", + "parts":[{"text": "How are you doing today?"}] + }] + }' +``` + + + + +```javascript +const { VertexAI } = require('@google-cloud/vertexai'); + +const vertexAI = new VertexAI({ + project: 'your-project-id', // enter your vertex project id + location: 'us-central1', // enter your vertex region + apiEndpoint: "localhost:4000/vertex-ai" // /vertex-ai # note, do not include 'https://' in the url +}); + +const model = vertexAI.getGenerativeModel({ + model: 'gemini-1.0-pro' +}, { + customHeaders: { + "X-Litellm-Api-Key": "sk-1234" // Your litellm Virtual Key + } +}); + +async function generateContent() { + try { + const prompt = { + contents: [{ + role: 'user', + parts: [{ text: 'How are you doing today?' 
}] + }] + }; + + const response = await model.generateContent(prompt); + console.log('Response:', response); + } catch (error) { + console.error('Error:', error); + } +} + +generateContent(); +``` + + + + + ## Supported API Endpoints - Gemini API diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index f86335636..d5e22b77c 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -33,7 +33,7 @@ const generativeModel = vertexAI.getGenerativeModel( requestOptions ); -async function testModel() { +async function streamingResponse() { try { const request = { contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], @@ -49,4 +49,20 @@ async function testModel() { } } -testModel(); \ No newline at end of file + +async function nonStreamingResponse() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const response = await generativeModel.generateContent(request); + console.log('non streaming response: ', JSON.stringify(response)); + } catch (error) { + console.error('Error:', error); + } +} + + + +streamingResponse(); +nonStreamingResponse(); \ No newline at end of file From 8aa18b39775f3241a2e72de10a7d1617dd289767 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:44:35 -0800 Subject: [PATCH 20/21] use get_litellm_virtual_key --- .../my-website/docs/pass_through/vertex_ai.md | 2 +- .../vertex_ai_endpoints/vertex_endpoints.py | 26 ++++++++++++------- tests/pass_through_tests/test_local_vertex.js | 2 +- tests/pass_through_tests/test_vertex.test.js | 4 +-- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md index e5491159f..03190c839 100644 --- a/docs/my-website/docs/pass_through/vertex_ai.md +++ b/docs/my-website/docs/pass_through/vertex_ai.md @@ -50,7 +50,7 @@ const model = vertexAI.getGenerativeModel({ model: 'gemini-1.0-pro' }, { customHeaders: { - "X-Litellm-Api-Key": "sk-1234" // Your litellm Virtual Key + "x-litellm-api-key": "sk-1234" // Your litellm Virtual Key } }); diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 73a38bdf7..fbf37ce8d 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -126,16 +126,7 @@ async def vertex_proxy_route( verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} - - # TODO - clean this up before merging - litellm_api_key = request.headers.get("X-Litellm-Api-Key") - api_key_to_use = "" - if litellm_api_key: - api_key_to_use = f"Bearer {litellm_api_key}" - else: - api_key_to_use = request.headers.get("Authorization") - - api_key_to_use = api_key_to_use or "" + api_key_to_use = get_litellm_virtual_key(request=request) user_api_key_dict = await user_api_key_auth( request=request, api_key=api_key_to_use, @@ -227,3 +218,18 @@ async def vertex_proxy_route( ) return received_value + + +def get_litellm_virtual_key(request: Request) -> str: + """ + Extract and format API key from request headers. + Prioritizes x-litellm-api-key over Authorization header. 
+ + + Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key + + """ + litellm_api_key = request.headers.get("x-litellm-api-key") + if litellm_api_key: + return f"Bearer {litellm_api_key}" + return request.headers.get("Authorization", "") diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index d5e22b77c..7ae9b942a 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -24,7 +24,7 @@ const vertexAI = new VertexAI({ // Use customHeaders in RequestOptions const requestOptions = { customHeaders: new Headers({ - "X-Litellm-Api-Key": "sk-1234" + "x-litellm-api-key": "sk-1234" }) }; diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js index 766050da7..dc457c68a 100644 --- a/tests/pass_through_tests/test_vertex.test.js +++ b/tests/pass_through_tests/test_vertex.test.js @@ -67,7 +67,7 @@ describe('Vertex AI Tests', () => { }); const customHeaders = new Headers({ - "X-Litellm-Api-Key": "sk-1234" + "x-litellm-api-key": "sk-1234" }); const requestOptions = { @@ -101,7 +101,7 @@ describe('Vertex AI Tests', () => { test('should successfully generate non-streaming content from Vertex AI', async () => { const vertexAI = new VertexAI({project: 'adroit-crow-413218', location: 'us-central1', apiEndpoint: "localhost:4000/vertex-ai"}); - const customHeaders = new Headers({"X-Litellm-Api-Key": "sk-1234"}); + const customHeaders = new Headers({"x-litellm-api-key": "sk-1234"}); const requestOptions = {customHeaders: customHeaders}; const generativeModel = vertexAI.getGenerativeModel({model: 'gemini-1.0-pro'}, requestOptions); const request = {contents: [{role: 'user', parts: [{text: 'What is 2+2?'}]}]}; From 0afdba082257a77b234723aee10c482df3c0fb17 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:49:35 -0800 Subject: [PATCH 21/21] add unit tests for vertex pass through --- .../test_unit_test_anthropic.py | 135 ------------------ .../test_unit_test_vertex_pass_through.py | 84 +++++++++++ 2 files changed, 84 insertions(+), 135 deletions(-) delete mode 100644 tests/pass_through_unit_tests/test_unit_test_anthropic.py create mode 100644 tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic.py b/tests/pass_through_unit_tests/test_unit_test_anthropic.py deleted file mode 100644 index afb77f718..000000000 --- a/tests/pass_through_unit_tests/test_unit_test_anthropic.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -import os -import sys -from datetime import datetime -from unittest.mock import AsyncMock, Mock, patch - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path - - -import httpx -import pytest -import litellm -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj - -# Import the class we're testing -from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import ( - AnthropicPassthroughLoggingHandler, -) - - -@pytest.fixture -def mock_response(): - return { - "model": "claude-3-opus-20240229", - "content": [{"text": "Hello, world!", "type": "text"}], - "role": "assistant", - } - - -@pytest.fixture -def mock_httpx_response(): - mock_resp = Mock(spec=httpx.Response) - mock_resp.json.return_value = { - "content": [{"text": "Hi! 
My name is Claude.", "type": "text"}], - "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", - "model": "claude-3-5-sonnet-20241022", - "role": "assistant", - "stop_reason": "end_turn", - "stop_sequence": None, - "type": "message", - "usage": {"input_tokens": 2095, "output_tokens": 503}, - } - mock_resp.status_code = 200 - mock_resp.headers = {"Content-Type": "application/json"} - return mock_resp - - -@pytest.fixture -def mock_logging_obj(): - logging_obj = LiteLLMLoggingObj( - model="claude-3-opus-20240229", - messages=[], - stream=False, - call_type="completion", - start_time=datetime.now(), - litellm_call_id="123", - function_id="456", - ) - - logging_obj.async_success_handler = AsyncMock() - return logging_obj - - -@pytest.mark.asyncio -async def test_anthropic_passthrough_handler( - mock_httpx_response, mock_response, mock_logging_obj -): - """ - Unit test - Assert that the anthropic passthrough handler calls the litellm logging object's async_success_handler - """ - start_time = datetime.now() - end_time = datetime.now() - - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( - httpx_response=mock_httpx_response, - response_body=mock_response, - logging_obj=mock_logging_obj, - url_route="/v1/chat/completions", - result="success", - start_time=start_time, - end_time=end_time, - cache_hit=False, - ) - - # Assert that async_success_handler was called - assert mock_logging_obj.async_success_handler.called - - call_args = mock_logging_obj.async_success_handler.call_args - call_kwargs = call_args.kwargs - print("call_kwargs", call_kwargs) - - # Assert required fields are present in call_kwargs - assert "result" in call_kwargs - assert "start_time" in call_kwargs - assert "end_time" in call_kwargs - assert "cache_hit" in call_kwargs - assert "response_cost" in call_kwargs - assert "model" in call_kwargs - assert "standard_logging_object" in call_kwargs - - # Assert specific values and types - assert isinstance(call_kwargs["result"], litellm.ModelResponse) - assert isinstance(call_kwargs["start_time"], datetime) - assert isinstance(call_kwargs["end_time"], datetime) - assert isinstance(call_kwargs["cache_hit"], bool) - assert isinstance(call_kwargs["response_cost"], float) - assert call_kwargs["model"] == "claude-3-opus-20240229" - assert isinstance(call_kwargs["standard_logging_object"], dict) - - -def test_create_anthropic_response_logging_payload(mock_logging_obj): - # Test the logging payload creation - model_response = litellm.ModelResponse() - model_response.choices = [{"message": {"content": "Test response"}}] - - start_time = datetime.now() - end_time = datetime.now() - - result = ( - AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( - litellm_model_response=model_response, - model="claude-3-opus-20240229", - kwargs={}, - start_time=start_time, - end_time=end_time, - logging_obj=mock_logging_obj, - ) - ) - - assert isinstance(result, dict) - assert "model" in result - assert "response_cost" in result - assert "standard_logging_object" in result diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py new file mode 100644 index 000000000..a7b668813 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -0,0 +1,84 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the 
system path + + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + +from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + get_litellm_virtual_key, + vertex_proxy_route, +) + + +@pytest.mark.asyncio +async def test_get_litellm_virtual_key(): + """ + Test that the get_litellm_virtual_key function correctly handles the API key authentication + """ + # Test with x-litellm-api-key + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + # Test with Authorization header + mock_request.headers = {"Authorization": "Bearer auth-key-456"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer auth-key-456" + + # Test with both headers (x-litellm-api-key should take precedence) + mock_request.headers = { + "x-litellm-api-key": "test-key-123", + "Authorization": "Bearer auth-key-456", + } + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + +@pytest.mark.asyncio +async def test_vertex_proxy_route_api_key_auth(): + """ + Critical + + This is how Vertex AI JS SDK will Auth to Litellm Proxy + """ + # Mock dependencies + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + mock_request.method = "POST" + mock_response = Mock() + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth" + ) as mock_auth: + mock_auth.return_value = {"api_key": "test-key-123"} + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route" + ) as mock_pass_through: + mock_pass_through.return_value = AsyncMock( + return_value={"status": "success"} + ) + + # Call the function + result = await vertex_proxy_route( + endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent", + request=mock_request, + fastapi_response=mock_response, + ) + + # Verify user_api_key_auth was called with the correct Bearer token + mock_auth.assert_called_once() + call_args = mock_auth.call_args[1] + assert call_args["api_key"] == "Bearer test-key-123"