From f83708ed4eb9d3e6939e6fcbdac0df9f3750b31b Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 22 Nov 2024 12:59:01 -0800
Subject: [PATCH 01/78] stash gemini JS test

---
 tests/pass_through_tests/test_gemini.js | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 tests/pass_through_tests/test_gemini.js

diff --git a/tests/pass_through_tests/test_gemini.js b/tests/pass_through_tests/test_gemini.js
new file mode 100644
index 000000000..2b7d6c5c6
--- /dev/null
+++ b/tests/pass_through_tests/test_gemini.js
@@ -0,0 +1,23 @@
+// const { GoogleGenerativeAI } = require("@google/generative-ai");
+
+// const genAI = new GoogleGenerativeAI("sk-1234");
+// const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" });
+
+// const prompt = "Explain how AI works in 2 pages";
+
+// async function run() {
+//   try {
+//     const result = await model.generateContentStream(prompt, { baseUrl: "http://localhost:4000/gemini" });
+//     const response = await result.response;
+//     console.log(response.text());
+//     for await (const chunk of result.stream) {
+//       const chunkText = chunk.text();
+//       console.log(chunkText);
+//       process.stdout.write(chunkText);
+//     }
+//   } catch (error) {
+//     console.error("Error:", error);
+//   }
+// }
+
+// run();
\ No newline at end of file

From dcab2d0c6f57e566fbee4027a83c94d8af2dfd09 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 22 Nov 2024 12:59:34 -0800
Subject: [PATCH 02/78] add vertex js sdk example

---
 tests/pass_through_tests/test_vertex.js | 43 +++++++++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 tests/pass_through_tests/test_vertex.js

diff --git a/tests/pass_through_tests/test_vertex.js b/tests/pass_through_tests/test_vertex.js
new file mode 100644
index 000000000..39db0ad0d
--- /dev/null
+++ b/tests/pass_through_tests/test_vertex.js
@@ -0,0 +1,43 @@
+
+const {
+    FunctionDeclarationSchemaType,
+    HarmBlockThreshold,
+    HarmCategory,
+    VertexAI,
+    RequestOptions
+  } = require('@google-cloud/vertexai');
+
+  const project = 'adroit-crow-413218';
+  const location = 'us-central1';
+  const textModel = 'gemini-1.0-pro';
+  const visionModel = 'gemini-1.0-pro-vision';
+
+
+  const vertexAI = new VertexAI({project: project, location: location, apiEndpoint: "localhost:4000/vertex-ai"});
+
+  // Instantiate Gemini models
+  const generativeModel = vertexAI.getGenerativeModel({
+    model: textModel,
+    // The following parameters are optional
+    // They can also be passed to individual content generation requests
+    safetySettings: [{category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}],
+    generationConfig: {maxOutputTokens: 256},
+    systemInstruction: {
+      role: 'system',
+      parts: [{"text": `For example, you are a helpful customer service agent. tell me your name.
in 5 pages`}] + }, + }) + +async function streamGenerateContent() { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + }; + + streamGenerateContent(); \ No newline at end of file From e829b228b24c266481b03b96b063f43e96b1d132 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 13:18:08 -0800 Subject: [PATCH 03/78] handle vertex pass through separately --- .../streaming_handler.py | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index 9ba5adfec..a88ad34d3 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -40,35 +40,39 @@ async def chunk_processor( - Yields chunks from the response - Collect non-empty chunks for post-processing (logging) """ - collected_chunks: List[str] = [] # List to store all chunks try: - async for chunk in response.aiter_lines(): - verbose_proxy_logger.debug(f"Processing chunk: {chunk}") - if not chunk: - continue + if endpoint_type == EndpointType.VERTEX_AI: + async for chunk in response.aiter_bytes(): + yield chunk + else: + collected_chunks: List[str] = [] # List to store all chunks + async for chunk in response.aiter_lines(): + verbose_proxy_logger.debug(f"Processing chunk: {chunk}") + if not chunk: + continue - # Handle SSE format - pass through the raw SSE format - if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8") + # Handle SSE format - pass through the raw SSE format + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8") - # Store the chunk for post-processing - if chunk.strip(): # Only store non-empty chunks - collected_chunks.append(chunk) - yield f"{chunk}\n" + # Store the chunk for post-processing + if chunk.strip(): # Only store non-empty chunks + collected_chunks.append(chunk) + yield f"{chunk}\n" - # After all chunks are processed, handle post-processing - end_time = datetime.now() + # After all chunks are processed, handle post-processing + end_time = datetime.now() - await _route_streaming_logging_to_handler( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body or {}, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=collected_chunks, - end_time=end_time, - ) + await _route_streaming_logging_to_handler( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body or {}, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=collected_chunks, + end_time=end_time, + ) except Exception as e: verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") From bbb2e029b562d7924c7186146c59d121b72436b0 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 13:18:23 -0800 Subject: [PATCH 04/78] tes vertex JS sdk --- tests/pass_through_tests/test_vertex.js | 75 ++++++++++++------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git 
a/tests/pass_through_tests/test_vertex.js b/tests/pass_through_tests/test_vertex.js index 39db0ad0d..1f2eaea33 100644 --- a/tests/pass_through_tests/test_vertex.js +++ b/tests/pass_through_tests/test_vertex.js @@ -1,43 +1,40 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); -const { - FunctionDeclarationSchemaType, - HarmBlockThreshold, - HarmCategory, - VertexAI, - RequestOptions - } = require('@google-cloud/vertexai'); - - const project = 'adroit-crow-413218'; - const location = 'us-central1'; - const textModel = 'gemini-1.0-pro'; - const visionModel = 'gemini-1.0-pro-vision'; +const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" +}); - - const vertexAI = new VertexAI({project: project, location: location, apiEndpoint: "localhost:4000/vertex-ai"}); - - // Instantiate Gemini models - const generativeModel = vertexAI.getGenerativeModel({ - model: textModel, - // The following parameters are optional - // They can also be passed to individual content generation requests - safetySettings: [{category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], - generationConfig: {maxOutputTokens: 256}, - systemInstruction: { - role: 'system', - parts: [{"text": `For example, you are a helpful customer service agent. tell me your name. in 5 pages`}] - }, - }) +// Create customHeaders using Headers +const customHeaders = new Headers({ + "X-Litellm-Api-Key": "sk-1234" +}); -async function streamGenerateContent() { - const request = { - contents: [{role: 'user', parts: [{text: 'How are you doing today?'}]}], - }; - const streamingResult = await generativeModel.generateContentStream(request); - for await (const item of streamingResult.stream) { - console.log('stream chunk: ', JSON.stringify(item)); +// Use customHeaders in RequestOptions +const requestOptions = { + customHeaders: customHeaders +}; + +const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions +); + +async function testModel() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + } catch (error) { + console.error('Error:', error); } - const aggregatedResponse = await streamingResult.response; - console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); - }; - - streamGenerateContent(); \ No newline at end of file +} + +testModel(); \ No newline at end of file From 4273837addb50fa16621529ccdca65f2f9c93503 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 13:19:01 -0800 Subject: [PATCH 05/78] fix vertex_proxy_route --- .../proxy/vertex_ai_endpoints/vertex_endpoints.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 2bd5b790c..73a38bdf7 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -119,7 +119,6 @@ async def vertex_proxy_route( endpoint: str, request: Request, fastapi_response: Response, - 
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): encoded_endpoint = httpx.URL(endpoint).path @@ -128,6 +127,20 @@ async def vertex_proxy_route( verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} + # TODO - clean this up before merging + litellm_api_key = request.headers.get("X-Litellm-Api-Key") + api_key_to_use = "" + if litellm_api_key: + api_key_to_use = f"Bearer {litellm_api_key}" + else: + api_key_to_use = request.headers.get("Authorization") + + api_key_to_use = api_key_to_use or "" + user_api_key_dict = await user_api_key_auth( + request=request, + api_key=api_key_to_use, + ) + vertex_project = None vertex_location = None # Use headers from the incoming request if default_vertex_config is not set From 04c9284da43982e692975f47e9b1ad3c126ad464 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 14:19:28 -0800 Subject: [PATCH 06/78] use PassThroughStreamingHandler --- .../streaming_handler.py | 164 ++++++++++-------- 1 file changed, 87 insertions(+), 77 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index a88ad34d3..522319aaa 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -27,97 +27,107 @@ from .success_handler import PassThroughEndpointLogging from .types import EndpointType -async def chunk_processor( - response: httpx.Response, - request_body: Optional[dict], - litellm_logging_obj: LiteLLMLoggingObj, - endpoint_type: EndpointType, - start_time: datetime, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, -): - """ - - Yields chunks from the response - - Collect non-empty chunks for post-processing (logging) - """ - try: - if endpoint_type == EndpointType.VERTEX_AI: +class PassThroughStreamingHandler: + + @staticmethod + async def chunk_processor( + response: httpx.Response, + request_body: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + ): + """ + - Yields chunks from the response + - Collect non-empty chunks for post-processing (logging) + """ + try: + raw_bytes: List[bytes] = [] async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) yield chunk - else: - collected_chunks: List[str] = [] # List to store all chunks - async for chunk in response.aiter_lines(): - verbose_proxy_logger.debug(f"Processing chunk: {chunk}") - if not chunk: - continue - - # Handle SSE format - pass through the raw SSE format - if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8") - - # Store the chunk for post-processing - if chunk.strip(): # Only store non-empty chunks - collected_chunks.append(chunk) - yield f"{chunk}\n" # After all chunks are processed, handle post-processing end_time = datetime.now() - await _route_streaming_logging_to_handler( + await PassThroughStreamingHandler._route_streaming_logging_to_handler( litellm_logging_obj=litellm_logging_obj, passthrough_success_handler_obj=passthrough_success_handler_obj, url_route=url_route, request_body=request_body or {}, endpoint_type=endpoint_type, start_time=start_time, - all_chunks=collected_chunks, + raw_bytes=raw_bytes, end_time=end_time, ) + except Exception as e: + verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") + raise - except Exception as e: - verbose_proxy_logger.error(f"Error in 
chunk_processor: {str(e)}") - raise + @staticmethod + async def _route_streaming_logging_to_handler( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + raw_bytes: List[bytes], + end_time: datetime, + ): + """ + Route the logging for the collected chunks to the appropriate handler - -async def _route_streaming_logging_to_handler( - litellm_logging_obj: LiteLLMLoggingObj, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, - request_body: dict, - endpoint_type: EndpointType, - start_time: datetime, - all_chunks: List[str], - end_time: datetime, -): - """ - Route the logging for the collected chunks to the appropriate handler - - Supported endpoint types: - - Anthropic - - Vertex AI - """ - if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, + Supported endpoint types: + - Anthropic + - Vertex AI + """ + all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( + raw_bytes ) - elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + if endpoint_type == EndpointType.ANTHROPIC: + await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.VERTEX_AI: + await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.GENERIC: + # No logging is supported for generic streaming endpoints + pass + + @staticmethod + def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: + """ + Converts a list of raw bytes into a list of string lines, similar to aiter_lines() + + Args: + raw_bytes: List of bytes chunks from aiter.bytes() + + Returns: + List of string lines, with each line being a complete data: {} chunk + """ + # Combine all bytes and decode to string + combined_str = b"".join(raw_bytes).decode("utf-8") + + # Split by newlines and filter out empty lines + lines = [line.strip() for line in combined_str.split("\n") if line.strip()] + + return lines From 7422af75fd799f7063f05e3a754bcf8596795294 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 14:20:21 -0800 
Subject: [PATCH 07/78] fix PassThroughStreamingHandler --- .../proxy/pass_through_endpoints/pass_through_endpoints.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index baf107a16..f60fd0166 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -36,7 +36,7 @@ from litellm.proxy._types import ( from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.secret_managers.main import get_secret_str -from .streaming_handler import chunk_processor +from .streaming_handler import PassThroughStreamingHandler from .success_handler import PassThroughEndpointLogging from .types import EndpointType, PassthroughStandardLoggingPayload @@ -448,7 +448,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, @@ -491,7 +491,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, From 49724153725a981597c25145ca9b9d9fe32b97f4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 14:35:02 -0800 Subject: [PATCH 08/78] use common _create_vertex_response_logging_payload_for_generate_content --- .../anthropic_passthrough_logging_handler.py | 2 +- .../vertex_passthrough_logging_handler.py | 62 ++++++++++++++++++- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 35cff0db3..ad5a98258 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -100,7 +100,7 @@ class AnthropicPassthroughLoggingHandler: kwargs["response_cost"] = response_cost kwargs["model"] = model - # Make standard logging object for Vertex AI + # Make standard logging object for Anthropic standard_logging_object = get_standard_logging_object_payload( kwargs=kwargs, init_response_obj=litellm_model_response, diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index fe61f32ee..275a0a119 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -56,8 +56,14 @@ class VertexPassthroughLoggingHandler: encoding=None, ) ) - logging_obj.model = litellm_model_response.model or model - logging_obj.model_call_details["model"] = logging_obj.model + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=litellm_model_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + ) await logging_obj.async_success_handler( 
result=litellm_model_response, @@ -147,6 +153,14 @@ class VertexPassthroughLoggingHandler: "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." ) return + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=complete_streaming_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=litellm_logging_obj, + ) await litellm_logging_obj.async_success_handler( result=complete_streaming_response, start_time=start_time, @@ -193,3 +207,47 @@ class VertexPassthroughLoggingHandler: if match: return match.group(1) return "unknown" + + @staticmethod + def _create_vertex_response_logging_payload_for_generate_content( + litellm_model_response: Union[ + litellm.ModelResponse, litellm.TextCompletionResponse + ], + model: str, + kwargs: dict, + start_time: datetime, + end_time: datetime, + logging_obj: LiteLLMLoggingObj, + ): + """ + Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming) + + """ + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + kwargs["response_cost"] = response_cost + kwargs["model"] = model + + # Make standard logging object for Vertex AI + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=litellm_model_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) + ) + kwargs["standard_logging_object"] = standard_logging_object + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + logging_obj.model = litellm_model_response.model or model + logging_obj.model_call_details["model"] = logging_obj.model + return kwargs From 53e82b7f14afee6ffb15327903b7d7cd7587972e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:17:59 -0800 Subject: [PATCH 09/78] test vertex js --- .circleci/config.yml | 41 ++++++++++++++++++- tests/pass_through_tests/test_local_vertex.js | 40 ++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 tests/pass_through_tests/test_local_vertex.js diff --git a/.circleci/config.yml b/.circleci/config.yml index db7c4ef5b..950a6cc0c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1094,6 +1094,26 @@ jobs: working_directory: ~/project steps: - checkout + # New steps to run Node.js test + - run: + name: Install Node.js + command: | + curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - + sudo apt-get install -y nodejs + node --version + npm --version + + - run: + name: Install Node.js dependencies + command: | + npm install @google-cloud/vertexai + npm install --save-dev jest + + - run: + name: Run Vertex AI tests + command: | + npx jest tests/pass_through_tests/test_vertex.test.js --verbose + no_output_timeout: 30m - run: name: Install Docker CLI (In case it's not already installed) command: | @@ -1172,6 +1192,26 @@ jobs: - run: name: Wait for app to be ready command: dockerize -wait http://localhost:4000 -timeout 5m + # New steps to run Node.js test + - run: + name: Install Node.js + command: | + curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - + sudo apt-get install -y nodejs + node --version + npm --version + + - run: + 
name: Install Node.js dependencies + command: | + npm install @google-cloud/vertexai + npm install --save-dev jest + + - run: + name: Run Vertex AI tests + command: | + npx jest tests/pass_through_tests/test_vertex.test.js --verbose + no_output_timeout: 30m - run: name: Run tests command: | @@ -1179,7 +1219,6 @@ jobs: ls python -m pytest -vv tests/pass_through_tests/ -x --junitxml=test-results/junit.xml --durations=5 no_output_timeout: 120m - # Store test results - store_test_results: path: test-results diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js new file mode 100644 index 000000000..1f2eaea33 --- /dev/null +++ b/tests/pass_through_tests/test_local_vertex.js @@ -0,0 +1,40 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + +const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" +}); + +// Create customHeaders using Headers +const customHeaders = new Headers({ + "X-Litellm-Api-Key": "sk-1234" +}); + +// Use customHeaders in RequestOptions +const requestOptions = { + customHeaders: customHeaders +}; + +const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions +); + +async function testModel() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + } catch (error) { + console.error('Error:', error); + } +} + +testModel(); \ No newline at end of file From d3f23e0528dd6dec8a5d170cb711f3f909f4f4ce Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:40:56 -0800 Subject: [PATCH 10/78] add working vertex jest tests --- tests/pass_through_tests/test_local_vertex.js | 14 ++++ tests/pass_through_tests/test_vertex.js | 40 ----------- tests/pass_through_tests/test_vertex.test.js | 71 +++++++++++++++++++ 3 files changed, 85 insertions(+), 40 deletions(-) delete mode 100644 tests/pass_through_tests/test_vertex.js create mode 100644 tests/pass_through_tests/test_vertex.test.js diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index 1f2eaea33..6e8fbeacd 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -1,5 +1,19 @@ const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + const vertexAI = new VertexAI({ project: 'adroit-crow-413218', location: 'us-central1', diff --git a/tests/pass_through_tests/test_vertex.js b/tests/pass_through_tests/test_vertex.js deleted file mode 100644 index 1f2eaea33..000000000 --- a/tests/pass_through_tests/test_vertex.js +++ /dev/null @@ -1,40 +0,0 @@ -const { VertexAI, RequestOptions 
} = require('@google-cloud/vertexai'); - -const vertexAI = new VertexAI({ - project: 'adroit-crow-413218', - location: 'us-central1', - apiEndpoint: "localhost:4000/vertex-ai" -}); - -// Create customHeaders using Headers -const customHeaders = new Headers({ - "X-Litellm-Api-Key": "sk-1234" -}); - -// Use customHeaders in RequestOptions -const requestOptions = { - customHeaders: customHeaders -}; - -const generativeModel = vertexAI.getGenerativeModel( - { model: 'gemini-1.0-pro' }, - requestOptions -); - -async function testModel() { - try { - const request = { - contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], - }; - const streamingResult = await generativeModel.generateContentStream(request); - for await (const item of streamingResult.stream) { - console.log('stream chunk: ', JSON.stringify(item)); - } - const aggregatedResponse = await streamingResult.response; - console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); - } catch (error) { - console.error('Error:', error); - } -} - -testModel(); \ No newline at end of file diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js new file mode 100644 index 000000000..d651650c2 --- /dev/null +++ b/tests/pass_through_tests/test_vertex.test.js @@ -0,0 +1,71 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + + +describe('Vertex AI Tests', () => { + test('should successfully generate content from Vertex AI', async () => { + const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" + }); + + const customHeaders = new Headers({ + "X-Litellm-Api-Key": "sk-1234" + }); + + const requestOptions = { + customHeaders: customHeaders + }; + + const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions + ); + + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + + const streamingResult = await generativeModel.generateContentStream(request); + + // Add some assertions + expect(streamingResult).toBeDefined(); + + for await (const item of streamingResult.stream) { + console.log('stream chunk:', JSON.stringify(item)); + expect(item).toBeDefined(); + } + + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response:', JSON.stringify(aggregatedResponse)); + expect(aggregatedResponse).toBeDefined(); + }); + + + test('should successfully generate non-streaming content from Vertex AI', async () => { + const vertexAI = new VertexAI({project: 'adroit-crow-413218', location: 'us-central1', apiEndpoint: "localhost:4000/vertex-ai"}); + const customHeaders = new Headers({"X-Litellm-Api-Key": "sk-1234"}); + const requestOptions = {customHeaders: customHeaders}; + const generativeModel = vertexAI.getGenerativeModel({model: 'gemini-1.0-pro'}, requestOptions); + const request = {contents: [{role: 'user', parts: [{text: 'What is 2+2?'}]}]}; + + const result = await 
generativeModel.generateContent(request); + expect(result).toBeDefined(); + expect(result.response).toBeDefined(); + console.log('non-streaming response:', JSON.stringify(result.response)); + }); +}); \ No newline at end of file From c7c586c8a6cc45f9e93ecb13a76e0ba0f238ddab Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:44:28 -0800 Subject: [PATCH 11/78] move basic bass through test --- .../test_anthropic_passthrough copy.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{anthropic_passthrough/test_anthropic_passthrough.py => pass_through_tests/test_anthropic_passthrough copy.py} (100%) diff --git a/tests/anthropic_passthrough/test_anthropic_passthrough.py b/tests/pass_through_tests/test_anthropic_passthrough copy.py similarity index 100% rename from tests/anthropic_passthrough/test_anthropic_passthrough.py rename to tests/pass_through_tests/test_anthropic_passthrough copy.py From 4b607e0cc280d073af68aefbd4092ed795045bab Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:44:57 -0800 Subject: [PATCH 12/78] use good name for test --- ...assthrough copy.py => test_anthropic_passthrough_python_sdkpy} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/pass_through_tests/{test_anthropic_passthrough copy.py => test_anthropic_passthrough_python_sdkpy} (100%) diff --git a/tests/pass_through_tests/test_anthropic_passthrough copy.py b/tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy similarity index 100% rename from tests/pass_through_tests/test_anthropic_passthrough copy.py rename to tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy From 4b576571a1c6c8a604fb1d0cdf3dc05ac5f7440f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 15:47:03 -0800 Subject: [PATCH 13/78] test vertex --- tests/pass_through_tests/test_vertex.test.js | 43 ++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js index d651650c2..766050da7 100644 --- a/tests/pass_through_tests/test_vertex.test.js +++ b/tests/pass_through_tests/test_vertex.test.js @@ -1,4 +1,8 @@ const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); +const fs = require('fs'); +const path = require('path'); +const os = require('os'); +const { writeFileSync } = require('fs'); // Import fetch if the SDK uses it @@ -14,6 +18,45 @@ global.fetch = async function patchedFetch(url, options) { return originalFetch(url, options); }; +function loadVertexAiCredentials() { + console.log("loading vertex ai credentials"); + const filepath = path.dirname(__filename); + const vertexKeyPath = path.join(filepath, "vertex_key.json"); + + // Initialize default empty service account data + let serviceAccountKeyData = {}; + + // Try to read existing vertex_key.json + try { + const content = fs.readFileSync(vertexKeyPath, 'utf8'); + if (content && content.trim()) { + serviceAccountKeyData = JSON.parse(content); + } + } catch (error) { + // File doesn't exist or is invalid, continue with empty object + } + + // Update with environment variables + const privateKeyId = process.env.VERTEX_AI_PRIVATE_KEY_ID || ""; + const privateKey = (process.env.VERTEX_AI_PRIVATE_KEY || "").replace(/\\n/g, "\n"); + + serviceAccountKeyData.private_key_id = privateKeyId; + serviceAccountKeyData.private_key = privateKey; + + // Create temporary file + const tempFilePath = path.join(os.tmpdir(), `vertex-credentials-${Date.now()}.json`); + writeFileSync(tempFilePath, 
JSON.stringify(serviceAccountKeyData, null, 2)); + + // Set environment variable + process.env.GOOGLE_APPLICATION_CREDENTIALS = tempFilePath; +} + +// Run credential loading before tests +beforeAll(() => { + loadVertexAiCredentials(); +}); + + describe('Vertex AI Tests', () => { test('should successfully generate content from Vertex AI', async () => { From 88dbb706c1e93902e44fed87f68e8475a4f58412 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:04:30 -0800 Subject: [PATCH 14/78] test_chunk_processor_yields_raw_bytes --- .../test_unit_test_anthropic_pass_through.py | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py new file mode 100644 index 000000000..afb77f718 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py @@ -0,0 +1,135 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + +# Import the class we're testing +from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import ( + AnthropicPassthroughLoggingHandler, +) + + +@pytest.fixture +def mock_response(): + return { + "model": "claude-3-opus-20240229", + "content": [{"text": "Hello, world!", "type": "text"}], + "role": "assistant", + } + + +@pytest.fixture +def mock_httpx_response(): + mock_resp = Mock(spec=httpx.Response) + mock_resp.json.return_value = { + "content": [{"text": "Hi! 
My name is Claude.", "type": "text"}], + "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", + "model": "claude-3-5-sonnet-20241022", + "role": "assistant", + "stop_reason": "end_turn", + "stop_sequence": None, + "type": "message", + "usage": {"input_tokens": 2095, "output_tokens": 503}, + } + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "application/json"} + return mock_resp + + +@pytest.fixture +def mock_logging_obj(): + logging_obj = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + ) + + logging_obj.async_success_handler = AsyncMock() + return logging_obj + + +@pytest.mark.asyncio +async def test_anthropic_passthrough_handler( + mock_httpx_response, mock_response, mock_logging_obj +): + """ + Unit test - Assert that the anthropic passthrough handler calls the litellm logging object's async_success_handler + """ + start_time = datetime.now() + end_time = datetime.now() + + await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + httpx_response=mock_httpx_response, + response_body=mock_response, + logging_obj=mock_logging_obj, + url_route="/v1/chat/completions", + result="success", + start_time=start_time, + end_time=end_time, + cache_hit=False, + ) + + # Assert that async_success_handler was called + assert mock_logging_obj.async_success_handler.called + + call_args = mock_logging_obj.async_success_handler.call_args + call_kwargs = call_args.kwargs + print("call_kwargs", call_kwargs) + + # Assert required fields are present in call_kwargs + assert "result" in call_kwargs + assert "start_time" in call_kwargs + assert "end_time" in call_kwargs + assert "cache_hit" in call_kwargs + assert "response_cost" in call_kwargs + assert "model" in call_kwargs + assert "standard_logging_object" in call_kwargs + + # Assert specific values and types + assert isinstance(call_kwargs["result"], litellm.ModelResponse) + assert isinstance(call_kwargs["start_time"], datetime) + assert isinstance(call_kwargs["end_time"], datetime) + assert isinstance(call_kwargs["cache_hit"], bool) + assert isinstance(call_kwargs["response_cost"], float) + assert call_kwargs["model"] == "claude-3-opus-20240229" + assert isinstance(call_kwargs["standard_logging_object"], dict) + + +def test_create_anthropic_response_logging_payload(mock_logging_obj): + # Test the logging payload creation + model_response = litellm.ModelResponse() + model_response.choices = [{"message": {"content": "Test response"}}] + + start_time = datetime.now() + end_time = datetime.now() + + result = ( + AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=model_response, + model="claude-3-opus-20240229", + kwargs={}, + start_time=start_time, + end_time=end_time, + logging_obj=mock_logging_obj, + ) + ) + + assert isinstance(result, dict) + assert "model" in result + assert "response_cost" in result + assert "standard_logging_object" in result From 413092ec1c104f2b9e6a4b4449a7ce9e34fa5545 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:04:45 -0800 Subject: [PATCH 15/78] unit tests for streaming --- .../test_unit_test_streaming.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/pass_through_unit_tests/test_unit_test_streaming.py diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py new file mode 100644 index 
000000000..845052382 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -0,0 +1,93 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch, MagicMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import httpx +import pytest +import litellm +from typing import AsyncGenerator +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy.pass_through_endpoints.types import EndpointType +from litellm.proxy.pass_through_endpoints.success_handler import ( + PassThroughEndpointLogging, +) +from litellm.proxy.pass_through_endpoints.streaming_handler import ( + PassThroughStreamingHandler, +) + + +# Helper function to mock async iteration +async def aiter_mock(iterable): + for item in iterable: + yield item + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "endpoint_type,url_route", + [ + ( + EndpointType.VERTEX_AI, + "v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent", + ), + (EndpointType.ANTHROPIC, "/v1/messages"), + ], +) +async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): + """ + Test that the chunk_processor yields raw bytes + + This is CRITICAL for pass throughs streaming with Vertex AI and Anthropic + """ + # Mock inputs + response = AsyncMock(spec=httpx.Response) + raw_chunks = [ + b'{"id": "1", "content": "Hello"}', + b'{"id": "2", "content": "World"}', + b'\n\ndata: {"id": "3"}', # Testing different byte formats + ] + + # Mock aiter_bytes to return an async generator + async def mock_aiter_bytes(): + for chunk in raw_chunks: + yield chunk + + response.aiter_bytes = mock_aiter_bytes + + request_body = {"key": "value"} + litellm_logging_obj = MagicMock() + start_time = datetime.now() + passthrough_success_handler_obj = MagicMock() + + # Capture yielded chunks and perform detailed assertions + received_chunks = [] + async for chunk in PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=request_body, + litellm_logging_obj=litellm_logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + ): + # Assert each chunk is bytes + assert isinstance(chunk, bytes), f"Chunk should be bytes, got {type(chunk)}" + # Assert no decoding/encoding occurred (chunk should be exactly as input) + assert ( + chunk in raw_chunks + ), f"Chunk {chunk} was modified during processing. 
For pass throughs streaming, chunks should be raw bytes" + received_chunks.append(chunk) + + # Assert all chunks were processed + assert len(received_chunks) == len(raw_chunks), "Not all chunks were processed" + + # collected chunks all together + assert b"".join(received_chunks) == b"".join( + raw_chunks + ), "Collected chunks do not match raw chunks" From 06da8a5fbc666c9dcc7bb4798a8633a258d1ac0b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:07:45 -0800 Subject: [PATCH 16/78] test_convert_raw_bytes_to_str_lines --- .../test_unit_test_streaming.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py index 845052382..bbbc465fc 100644 --- a/tests/pass_through_unit_tests/test_unit_test_streaming.py +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -91,3 +91,28 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): assert b"".join(received_chunks) == b"".join( raw_chunks ), "Collected chunks do not match raw chunks" + + +def test_convert_raw_bytes_to_str_lines(): + """ + Test that the _convert_raw_bytes_to_str_lines method correctly converts raw bytes to a list of strings + """ + # Test case 1: Single chunk + raw_bytes = [b'data: {"content": "Hello"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}'] + + # Test case 2: Multiple chunks + raw_bytes = [b'data: {"content": "Hello"}\n', b'data: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] + + # Test case 3: Empty input + raw_bytes = [] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == [] + + # Test case 4: Chunks with empty lines + raw_bytes = [b'data: {"content": "Hello"}\n\n', b'\ndata: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] From 35040f12becb9fd50b5064a3e4321ae2650b03ec Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:15:37 -0800 Subject: [PATCH 17/78] run unit tests 1st --- .circleci/config.yml | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 950a6cc0c..2f9df4d52 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1094,26 +1094,6 @@ jobs: working_directory: ~/project steps: - checkout - # New steps to run Node.js test - - run: - name: Install Node.js - command: | - curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - - sudo apt-get install -y nodejs - node --version - npm --version - - - run: - name: Install Node.js dependencies - command: | - npm install @google-cloud/vertexai - npm install --save-dev jest - - - run: - name: Run Vertex AI tests - command: | - npx jest tests/pass_through_tests/test_vertex.test.js --verbose - no_output_timeout: 30m - run: name: Install Docker CLI (In case it's not already installed) command: | @@ -1157,6 +1137,15 @@ jobs: pip install "PyGithub==1.59.1" pip install "google-cloud-aiplatform==1.59.0" pip install anthropic + python -m pip install -r requirements.txt + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls 
+ python -m pytest -vv tests/pass_through_unit_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m - run: name: Build Docker image command: docker build -t my-app:latest -f ./docker/Dockerfile.database . @@ -1192,7 +1181,7 @@ jobs: - run: name: Wait for app to be ready command: dockerize -wait http://localhost:4000 -timeout 5m - # New steps to run Node.js test + # New steps to run Node.js test - run: name: Install Node.js command: | From 377cfeb24f3e25edb3454e41c3fa69b75476883c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:20:16 -0800 Subject: [PATCH 18/78] add pass_through_unit_testing --- .circleci/config.yml | 50 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index db7c4ef5b..3b63f7487 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -625,6 +625,48 @@ jobs: paths: - llm_translation_coverage.xml - llm_translation_coverage + pass_through_unit_testing: + docker: + - image: cimg/python:3.11 + auth: + username: ${DOCKERHUB_USERNAME} + password: ${DOCKERHUB_PASSWORD} + working_directory: ~/project + + steps: + - checkout + - run: + name: Install Dependencies + command: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + pip install "pytest==7.3.1" + pip install "pytest-retry==1.6.3" + pip install "pytest-cov==5.0.0" + pip install "pytest-asyncio==0.21.1" + pip install "respx==0.21.1" + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/pass_through_unit_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m + - run: + name: Rename the coverage files + command: | + mv coverage.xml pass_through_unit_tests_coverage.xml + mv .coverage pass_through_unit_tests_coverage + + # Store test results + - store_test_results: + path: test-results + - persist_to_workspace: + root: . 
+ paths: + - pass_through_unit_tests_coverage.xml + - pass_through_unit_tests_coverage image_gen_testing: docker: - image: cimg/python:3.11 @@ -1494,6 +1536,12 @@ workflows: only: - main - /litellm_.*/ + - pass_through_unit_testing: + filters: + branches: + only: + - main + - /litellm_.*/ - image_gen_testing: filters: branches: @@ -1509,6 +1557,7 @@ workflows: - upload-coverage: requires: - llm_translation_testing + - pass_through_unit_testing - image_gen_testing - logging_testing - litellm_router_testing @@ -1549,6 +1598,7 @@ workflows: - load_testing - test_bad_database_url - llm_translation_testing + - pass_through_unit_testing - image_gen_testing - logging_testing - litellm_router_testing From 5930c42e74d34b580792d2047bf3f157debd9722 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:21:22 -0800 Subject: [PATCH 19/78] fix coverage --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 3b63f7487..e86c1cb56 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -965,7 +965,7 @@ jobs: command: | pwd ls - python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests + python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests no_output_timeout: 120m # Store test results @@ -1247,7 +1247,7 @@ jobs: python -m venv venv . 
venv/bin/activate pip install coverage - coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage + coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage coverage xml - codecov/upload: file: ./coverage.xml From 77fe5af5b359a84223b535d9a0832f0d13952c50 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:31:58 -0800 Subject: [PATCH 20/78] simplify local --- tests/pass_through_tests/test_local_vertex.js | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index 6e8fbeacd..f86335636 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -20,14 +20,12 @@ const vertexAI = new VertexAI({ apiEndpoint: "localhost:4000/vertex-ai" }); -// Create customHeaders using Headers -const customHeaders = new Headers({ - "X-Litellm-Api-Key": "sk-1234" -}); // Use customHeaders in RequestOptions const requestOptions = { - customHeaders: customHeaders + customHeaders: new Headers({ + "X-Litellm-Api-Key": "sk-1234" + }) }; const generativeModel = vertexAI.getGenerativeModel( From 7674217e6c2c0127ce971ebabadd356cfa0c0d63 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:40:40 -0800 Subject: [PATCH 21/78] docs add usage example for js --- .../my-website/docs/pass_through/vertex_ai.md | 65 +++++++++++++++++++ tests/pass_through_tests/test_local_vertex.js | 20 +++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md index 07b0beb75..e5491159f 100644 --- a/docs/my-website/docs/pass_through/vertex_ai.md +++ b/docs/my-website/docs/pass_through/vertex_ai.md @@ -12,6 +12,71 @@ Looking for the Unified API (OpenAI format) for VertexAI ? [Go here - using vert ::: +Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation). + +Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex-ai` + + +#### **Example Usage** + + + + +```bash +curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "contents":[{ + "role": "user", + "parts":[{"text": "How are you doing today?"}] + }] + }' +``` + + + + +```javascript +const { VertexAI } = require('@google-cloud/vertexai'); + +const vertexAI = new VertexAI({ + project: 'your-project-id', // enter your vertex project id + location: 'us-central1', // enter your vertex region + apiEndpoint: "localhost:4000/vertex-ai" // /vertex-ai # note, do not include 'https://' in the url +}); + +const model = vertexAI.getGenerativeModel({ + model: 'gemini-1.0-pro' +}, { + customHeaders: { + "X-Litellm-Api-Key": "sk-1234" // Your litellm Virtual Key + } +}); + +async function generateContent() { + try { + const prompt = { + contents: [{ + role: 'user', + parts: [{ text: 'How are you doing today?' 
}] + }] + }; + + const response = await model.generateContent(prompt); + console.log('Response:', response); + } catch (error) { + console.error('Error:', error); + } +} + +generateContent(); +``` + + + + + ## Supported API Endpoints - Gemini API diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index f86335636..d5e22b77c 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -33,7 +33,7 @@ const generativeModel = vertexAI.getGenerativeModel( requestOptions ); -async function testModel() { +async function streamingResponse() { try { const request = { contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], @@ -49,4 +49,20 @@ async function testModel() { } } -testModel(); \ No newline at end of file + +async function nonStreamingResponse() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const response = await generativeModel.generateContent(request); + console.log('non streaming response: ', JSON.stringify(response)); + } catch (error) { + console.error('Error:', error); + } +} + + + +streamingResponse(); +nonStreamingResponse(); \ No newline at end of file From 8aa18b39775f3241a2e72de10a7d1617dd289767 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:44:35 -0800 Subject: [PATCH 22/78] use get_litellm_virtual_key --- .../my-website/docs/pass_through/vertex_ai.md | 2 +- .../vertex_ai_endpoints/vertex_endpoints.py | 26 ++++++++++++------- tests/pass_through_tests/test_local_vertex.js | 2 +- tests/pass_through_tests/test_vertex.test.js | 4 +-- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md index e5491159f..03190c839 100644 --- a/docs/my-website/docs/pass_through/vertex_ai.md +++ b/docs/my-website/docs/pass_through/vertex_ai.md @@ -50,7 +50,7 @@ const model = vertexAI.getGenerativeModel({ model: 'gemini-1.0-pro' }, { customHeaders: { - "X-Litellm-Api-Key": "sk-1234" // Your litellm Virtual Key + "x-litellm-api-key": "sk-1234" // Your litellm Virtual Key } }); diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 73a38bdf7..fbf37ce8d 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -126,16 +126,7 @@ async def vertex_proxy_route( verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} - - # TODO - clean this up before merging - litellm_api_key = request.headers.get("X-Litellm-Api-Key") - api_key_to_use = "" - if litellm_api_key: - api_key_to_use = f"Bearer {litellm_api_key}" - else: - api_key_to_use = request.headers.get("Authorization") - - api_key_to_use = api_key_to_use or "" + api_key_to_use = get_litellm_virtual_key(request=request) user_api_key_dict = await user_api_key_auth( request=request, api_key=api_key_to_use, @@ -227,3 +218,18 @@ async def vertex_proxy_route( ) return received_value + + +def get_litellm_virtual_key(request: Request) -> str: + """ + Extract and format API key from request headers. + Prioritizes x-litellm-api-key over Authorization header. 
+ + + Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key + + """ + litellm_api_key = request.headers.get("x-litellm-api-key") + if litellm_api_key: + return f"Bearer {litellm_api_key}" + return request.headers.get("Authorization", "") diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js index d5e22b77c..7ae9b942a 100644 --- a/tests/pass_through_tests/test_local_vertex.js +++ b/tests/pass_through_tests/test_local_vertex.js @@ -24,7 +24,7 @@ const vertexAI = new VertexAI({ // Use customHeaders in RequestOptions const requestOptions = { customHeaders: new Headers({ - "X-Litellm-Api-Key": "sk-1234" + "x-litellm-api-key": "sk-1234" }) }; diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js index 766050da7..dc457c68a 100644 --- a/tests/pass_through_tests/test_vertex.test.js +++ b/tests/pass_through_tests/test_vertex.test.js @@ -67,7 +67,7 @@ describe('Vertex AI Tests', () => { }); const customHeaders = new Headers({ - "X-Litellm-Api-Key": "sk-1234" + "x-litellm-api-key": "sk-1234" }); const requestOptions = { @@ -101,7 +101,7 @@ describe('Vertex AI Tests', () => { test('should successfully generate non-streaming content from Vertex AI', async () => { const vertexAI = new VertexAI({project: 'adroit-crow-413218', location: 'us-central1', apiEndpoint: "localhost:4000/vertex-ai"}); - const customHeaders = new Headers({"X-Litellm-Api-Key": "sk-1234"}); + const customHeaders = new Headers({"x-litellm-api-key": "sk-1234"}); const requestOptions = {customHeaders: customHeaders}; const generativeModel = vertexAI.getGenerativeModel({model: 'gemini-1.0-pro'}, requestOptions); const request = {contents: [{role: 'user', parts: [{text: 'What is 2+2?'}]}]}; From 0afdba082257a77b234723aee10c482df3c0fb17 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:49:35 -0800 Subject: [PATCH 23/78] add unit tests for vertex pass through --- .../test_unit_test_anthropic.py | 135 ------------------ .../test_unit_test_vertex_pass_through.py | 84 +++++++++++ 2 files changed, 84 insertions(+), 135 deletions(-) delete mode 100644 tests/pass_through_unit_tests/test_unit_test_anthropic.py create mode 100644 tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic.py b/tests/pass_through_unit_tests/test_unit_test_anthropic.py deleted file mode 100644 index afb77f718..000000000 --- a/tests/pass_through_unit_tests/test_unit_test_anthropic.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -import os -import sys -from datetime import datetime -from unittest.mock import AsyncMock, Mock, patch - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path - - -import httpx -import pytest -import litellm -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj - -# Import the class we're testing -from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import ( - AnthropicPassthroughLoggingHandler, -) - - -@pytest.fixture -def mock_response(): - return { - "model": "claude-3-opus-20240229", - "content": [{"text": "Hello, world!", "type": "text"}], - "role": "assistant", - } - - -@pytest.fixture -def mock_httpx_response(): - mock_resp = Mock(spec=httpx.Response) - mock_resp.json.return_value = { - "content": [{"text": "Hi! 
My name is Claude.", "type": "text"}], - "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", - "model": "claude-3-5-sonnet-20241022", - "role": "assistant", - "stop_reason": "end_turn", - "stop_sequence": None, - "type": "message", - "usage": {"input_tokens": 2095, "output_tokens": 503}, - } - mock_resp.status_code = 200 - mock_resp.headers = {"Content-Type": "application/json"} - return mock_resp - - -@pytest.fixture -def mock_logging_obj(): - logging_obj = LiteLLMLoggingObj( - model="claude-3-opus-20240229", - messages=[], - stream=False, - call_type="completion", - start_time=datetime.now(), - litellm_call_id="123", - function_id="456", - ) - - logging_obj.async_success_handler = AsyncMock() - return logging_obj - - -@pytest.mark.asyncio -async def test_anthropic_passthrough_handler( - mock_httpx_response, mock_response, mock_logging_obj -): - """ - Unit test - Assert that the anthropic passthrough handler calls the litellm logging object's async_success_handler - """ - start_time = datetime.now() - end_time = datetime.now() - - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( - httpx_response=mock_httpx_response, - response_body=mock_response, - logging_obj=mock_logging_obj, - url_route="/v1/chat/completions", - result="success", - start_time=start_time, - end_time=end_time, - cache_hit=False, - ) - - # Assert that async_success_handler was called - assert mock_logging_obj.async_success_handler.called - - call_args = mock_logging_obj.async_success_handler.call_args - call_kwargs = call_args.kwargs - print("call_kwargs", call_kwargs) - - # Assert required fields are present in call_kwargs - assert "result" in call_kwargs - assert "start_time" in call_kwargs - assert "end_time" in call_kwargs - assert "cache_hit" in call_kwargs - assert "response_cost" in call_kwargs - assert "model" in call_kwargs - assert "standard_logging_object" in call_kwargs - - # Assert specific values and types - assert isinstance(call_kwargs["result"], litellm.ModelResponse) - assert isinstance(call_kwargs["start_time"], datetime) - assert isinstance(call_kwargs["end_time"], datetime) - assert isinstance(call_kwargs["cache_hit"], bool) - assert isinstance(call_kwargs["response_cost"], float) - assert call_kwargs["model"] == "claude-3-opus-20240229" - assert isinstance(call_kwargs["standard_logging_object"], dict) - - -def test_create_anthropic_response_logging_payload(mock_logging_obj): - # Test the logging payload creation - model_response = litellm.ModelResponse() - model_response.choices = [{"message": {"content": "Test response"}}] - - start_time = datetime.now() - end_time = datetime.now() - - result = ( - AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( - litellm_model_response=model_response, - model="claude-3-opus-20240229", - kwargs={}, - start_time=start_time, - end_time=end_time, - logging_obj=mock_logging_obj, - ) - ) - - assert isinstance(result, dict) - assert "model" in result - assert "response_cost" in result - assert "standard_logging_object" in result diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py new file mode 100644 index 000000000..a7b668813 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -0,0 +1,84 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the 
system path + + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + +from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + get_litellm_virtual_key, + vertex_proxy_route, +) + + +@pytest.mark.asyncio +async def test_get_litellm_virtual_key(): + """ + Test that the get_litellm_virtual_key function correctly handles the API key authentication + """ + # Test with x-litellm-api-key + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + # Test with Authorization header + mock_request.headers = {"Authorization": "Bearer auth-key-456"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer auth-key-456" + + # Test with both headers (x-litellm-api-key should take precedence) + mock_request.headers = { + "x-litellm-api-key": "test-key-123", + "Authorization": "Bearer auth-key-456", + } + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + +@pytest.mark.asyncio +async def test_vertex_proxy_route_api_key_auth(): + """ + Critical + + This is how Vertex AI JS SDK will Auth to Litellm Proxy + """ + # Mock dependencies + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + mock_request.method = "POST" + mock_response = Mock() + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth" + ) as mock_auth: + mock_auth.return_value = {"api_key": "test-key-123"} + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route" + ) as mock_pass_through: + mock_pass_through.return_value = AsyncMock( + return_value={"status": "success"} + ) + + # Call the function + result = await vertex_proxy_route( + endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent", + request=mock_request, + fastapi_response=mock_response, + ) + + # Verify user_api_key_auth was called with the correct Bearer token + mock_auth.assert_called_once() + call_args = mock_auth.call_args[1] + assert call_args["api_key"] == "Bearer test-key-123" From b2b3e40d13d1e424efe7f4bae83e341e29ac009d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 16:50:10 -0800 Subject: [PATCH 24/78] (feat) use `@google-cloud/vertexai` js sdk with litellm (#6873) * stash gemini JS test * add vertex js sdj example * handle vertex pass through separately * tes vertex JS sdk * fix vertex_proxy_route * use PassThroughStreamingHandler * fix PassThroughStreamingHandler * use common _create_vertex_response_logging_payload_for_generate_content * test vertex js * add working vertex jest tests * move basic bass through test * use good name for test * test vertex * test_chunk_processor_yields_raw_bytes * unit tests for streaming * test_convert_raw_bytes_to_str_lines * run unit tests 1st * simplify local * docs add usage example for js * use get_litellm_virtual_key * add unit tests for vertex pass through --- .circleci/config.yml | 30 ++- .../my-website/docs/pass_through/vertex_ai.md | 65 +++++++ .../anthropic_passthrough_logging_handler.py | 2 +- .../vertex_passthrough_logging_handler.py | 62 +++++- .../pass_through_endpoints.py | 6 +- .../streaming_handler.py | 176 ++++++++++-------- .../vertex_ai_endpoints/vertex_endpoints.py | 21 ++- .../test_anthropic_passthrough_python_sdkpy} | 0 tests/pass_through_tests/test_gemini.js | 23 +++ 
tests/pass_through_tests/test_local_vertex.js | 68 +++++++ tests/pass_through_tests/test_vertex.test.js | 114 ++++++++++++ ... test_unit_test_anthropic_pass_through.py} | 0 .../test_unit_test_streaming.py | 118 ++++++++++++ .../test_unit_test_vertex_pass_through.py | 84 +++++++++ 14 files changed, 680 insertions(+), 89 deletions(-) rename tests/{anthropic_passthrough/test_anthropic_passthrough.py => pass_through_tests/test_anthropic_passthrough_python_sdkpy} (100%) create mode 100644 tests/pass_through_tests/test_gemini.js create mode 100644 tests/pass_through_tests/test_local_vertex.js create mode 100644 tests/pass_through_tests/test_vertex.test.js rename tests/pass_through_unit_tests/{test_unit_test_anthropic.py => test_unit_test_anthropic_pass_through.py} (100%) create mode 100644 tests/pass_through_unit_tests/test_unit_test_streaming.py create mode 100644 tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py diff --git a/.circleci/config.yml b/.circleci/config.yml index e86c1cb56..1d7ed7602 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1179,6 +1179,15 @@ jobs: pip install "PyGithub==1.59.1" pip install "google-cloud-aiplatform==1.59.0" pip install anthropic + python -m pip install -r requirements.txt + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/pass_through_unit_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m - run: name: Build Docker image command: docker build -t my-app:latest -f ./docker/Dockerfile.database . @@ -1214,6 +1223,26 @@ jobs: - run: name: Wait for app to be ready command: dockerize -wait http://localhost:4000 -timeout 5m + # New steps to run Node.js test + - run: + name: Install Node.js + command: | + curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash - + sudo apt-get install -y nodejs + node --version + npm --version + + - run: + name: Install Node.js dependencies + command: | + npm install @google-cloud/vertexai + npm install --save-dev jest + + - run: + name: Run Vertex AI tests + command: | + npx jest tests/pass_through_tests/test_vertex.test.js --verbose + no_output_timeout: 30m - run: name: Run tests command: | @@ -1221,7 +1250,6 @@ jobs: ls python -m pytest -vv tests/pass_through_tests/ -x --junitxml=test-results/junit.xml --durations=5 no_output_timeout: 120m - # Store test results - store_test_results: path: test-results diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md index 07b0beb75..03190c839 100644 --- a/docs/my-website/docs/pass_through/vertex_ai.md +++ b/docs/my-website/docs/pass_through/vertex_ai.md @@ -12,6 +12,71 @@ Looking for the Unified API (OpenAI format) for VertexAI ? [Go here - using vert ::: +Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation). 
+ +Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex-ai` + + +#### **Example Usage** + + + + +```bash +curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "contents":[{ + "role": "user", + "parts":[{"text": "How are you doing today?"}] + }] + }' +``` + + + + +```javascript +const { VertexAI } = require('@google-cloud/vertexai'); + +const vertexAI = new VertexAI({ + project: 'your-project-id', // enter your vertex project id + location: 'us-central1', // enter your vertex region + apiEndpoint: "localhost:4000/vertex-ai" // /vertex-ai # note, do not include 'https://' in the url +}); + +const model = vertexAI.getGenerativeModel({ + model: 'gemini-1.0-pro' +}, { + customHeaders: { + "x-litellm-api-key": "sk-1234" // Your litellm Virtual Key + } +}); + +async function generateContent() { + try { + const prompt = { + contents: [{ + role: 'user', + parts: [{ text: 'How are you doing today?' }] + }] + }; + + const response = await model.generateContent(prompt); + console.log('Response:', response); + } catch (error) { + console.error('Error:', error); + } +} + +generateContent(); +``` + + + + + ## Supported API Endpoints - Gemini API diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 35cff0db3..ad5a98258 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -100,7 +100,7 @@ class AnthropicPassthroughLoggingHandler: kwargs["response_cost"] = response_cost kwargs["model"] = model - # Make standard logging object for Vertex AI + # Make standard logging object for Anthropic standard_logging_object = get_standard_logging_object_payload( kwargs=kwargs, init_response_obj=litellm_model_response, diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index fe61f32ee..275a0a119 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -56,8 +56,14 @@ class VertexPassthroughLoggingHandler: encoding=None, ) ) - logging_obj.model = litellm_model_response.model or model - logging_obj.model_call_details["model"] = logging_obj.model + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=litellm_model_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + ) await logging_obj.async_success_handler( result=litellm_model_response, @@ -147,6 +153,14 @@ class VertexPassthroughLoggingHandler: "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." 
) return + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=complete_streaming_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=litellm_logging_obj, + ) await litellm_logging_obj.async_success_handler( result=complete_streaming_response, start_time=start_time, @@ -193,3 +207,47 @@ class VertexPassthroughLoggingHandler: if match: return match.group(1) return "unknown" + + @staticmethod + def _create_vertex_response_logging_payload_for_generate_content( + litellm_model_response: Union[ + litellm.ModelResponse, litellm.TextCompletionResponse + ], + model: str, + kwargs: dict, + start_time: datetime, + end_time: datetime, + logging_obj: LiteLLMLoggingObj, + ): + """ + Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming) + + """ + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + kwargs["response_cost"] = response_cost + kwargs["model"] = model + + # Make standard logging object for Vertex AI + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=litellm_model_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) + ) + kwargs["standard_logging_object"] = standard_logging_object + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + logging_obj.model = litellm_model_response.model or model + logging_obj.model_call_details["model"] = logging_obj.model + return kwargs diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index baf107a16..f60fd0166 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -36,7 +36,7 @@ from litellm.proxy._types import ( from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.secret_managers.main import get_secret_str -from .streaming_handler import chunk_processor +from .streaming_handler import PassThroughStreamingHandler from .success_handler import PassThroughEndpointLogging from .types import EndpointType, PassthroughStandardLoggingPayload @@ -448,7 +448,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, @@ -491,7 +491,7 @@ async def pass_through_request( # noqa: PLR0915 ) return StreamingResponse( - chunk_processor( + PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index 9ba5adfec..522319aaa 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -27,93 +27,107 @@ from .success_handler import PassThroughEndpointLogging from .types import EndpointType -async def chunk_processor( - response: httpx.Response, - request_body: Optional[dict], - litellm_logging_obj: 
LiteLLMLoggingObj, - endpoint_type: EndpointType, - start_time: datetime, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, -): - """ - - Yields chunks from the response - - Collect non-empty chunks for post-processing (logging) - """ - collected_chunks: List[str] = [] # List to store all chunks - try: - async for chunk in response.aiter_lines(): - verbose_proxy_logger.debug(f"Processing chunk: {chunk}") - if not chunk: - continue +class PassThroughStreamingHandler: - # Handle SSE format - pass through the raw SSE format - if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8") + @staticmethod + async def chunk_processor( + response: httpx.Response, + request_body: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + ): + """ + - Yields chunks from the response + - Collect non-empty chunks for post-processing (logging) + """ + try: + raw_bytes: List[bytes] = [] + async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) + yield chunk - # Store the chunk for post-processing - if chunk.strip(): # Only store non-empty chunks - collected_chunks.append(chunk) - yield f"{chunk}\n" + # After all chunks are processed, handle post-processing + end_time = datetime.now() - # After all chunks are processed, handle post-processing - end_time = datetime.now() + await PassThroughStreamingHandler._route_streaming_logging_to_handler( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body or {}, + endpoint_type=endpoint_type, + start_time=start_time, + raw_bytes=raw_bytes, + end_time=end_time, + ) + except Exception as e: + verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") + raise - await _route_streaming_logging_to_handler( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body or {}, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=collected_chunks, - end_time=end_time, + @staticmethod + async def _route_streaming_logging_to_handler( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + raw_bytes: List[bytes], + end_time: datetime, + ): + """ + Route the logging for the collected chunks to the appropriate handler + + Supported endpoint types: + - Anthropic + - Vertex AI + """ + all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( + raw_bytes ) + if endpoint_type == EndpointType.ANTHROPIC: + await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.VERTEX_AI: + await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + 
end_time=end_time, + ) + elif endpoint_type == EndpointType.GENERIC: + # No logging is supported for generic streaming endpoints + pass - except Exception as e: - verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") - raise + @staticmethod + def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: + """ + Converts a list of raw bytes into a list of string lines, similar to aiter_lines() + Args: + raw_bytes: List of bytes chunks from aiter.bytes() -async def _route_streaming_logging_to_handler( - litellm_logging_obj: LiteLLMLoggingObj, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, - request_body: dict, - endpoint_type: EndpointType, - start_time: datetime, - all_chunks: List[str], - end_time: datetime, -): - """ - Route the logging for the collected chunks to the appropriate handler + Returns: + List of string lines, with each line being a complete data: {} chunk + """ + # Combine all bytes and decode to string + combined_str = b"".join(raw_bytes).decode("utf-8") - Supported endpoint types: - - Anthropic - - Vertex AI - """ - if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + # Split by newlines and filter out empty lines + lines = [line.strip() for line in combined_str.split("\n") if line.strip()] + + return lines diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 2bd5b790c..fbf37ce8d 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -119,7 +119,6 @@ async def vertex_proxy_route( endpoint: str, request: Request, fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): encoded_endpoint = httpx.URL(endpoint).path @@ -127,6 +126,11 @@ async def vertex_proxy_route( verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} + api_key_to_use = get_litellm_virtual_key(request=request) + user_api_key_dict = await user_api_key_auth( + request=request, + api_key=api_key_to_use, + ) vertex_project = None vertex_location = None @@ -214,3 +218,18 @@ async def vertex_proxy_route( ) return received_value + + +def get_litellm_virtual_key(request: Request) -> str: + """ + Extract and format API key from request headers. + Prioritizes x-litellm-api-key over Authorization header. 
+ + + Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key + + """ + litellm_api_key = request.headers.get("x-litellm-api-key") + if litellm_api_key: + return f"Bearer {litellm_api_key}" + return request.headers.get("Authorization", "") diff --git a/tests/anthropic_passthrough/test_anthropic_passthrough.py b/tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy similarity index 100% rename from tests/anthropic_passthrough/test_anthropic_passthrough.py rename to tests/pass_through_tests/test_anthropic_passthrough_python_sdkpy diff --git a/tests/pass_through_tests/test_gemini.js b/tests/pass_through_tests/test_gemini.js new file mode 100644 index 000000000..2b7d6c5c6 --- /dev/null +++ b/tests/pass_through_tests/test_gemini.js @@ -0,0 +1,23 @@ +// const { GoogleGenerativeAI } = require("@google/generative-ai"); + +// const genAI = new GoogleGenerativeAI("sk-1234"); +// const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" }); + +// const prompt = "Explain how AI works in 2 pages"; + +// async function run() { +// try { +// const result = await model.generateContentStream(prompt, { baseUrl: "http://localhost:4000/gemini" }); +// const response = await result.response; +// console.log(response.text()); +// for await (const chunk of result.stream) { +// const chunkText = chunk.text(); +// console.log(chunkText); +// process.stdout.write(chunkText); +// } +// } catch (error) { +// console.error("Error:", error); +// } +// } + +// run(); \ No newline at end of file diff --git a/tests/pass_through_tests/test_local_vertex.js b/tests/pass_through_tests/test_local_vertex.js new file mode 100644 index 000000000..7ae9b942a --- /dev/null +++ b/tests/pass_through_tests/test_local_vertex.js @@ -0,0 +1,68 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); + + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + +const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" +}); + + +// Use customHeaders in RequestOptions +const requestOptions = { + customHeaders: new Headers({ + "x-litellm-api-key": "sk-1234" + }) +}; + +const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions +); + +async function streamingResponse() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const streamingResult = await generativeModel.generateContentStream(request); + for await (const item of streamingResult.stream) { + console.log('stream chunk: ', JSON.stringify(item)); + } + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response: ', JSON.stringify(aggregatedResponse)); + } catch (error) { + console.error('Error:', error); + } +} + + +async function nonStreamingResponse() { + try { + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + const response = await generativeModel.generateContent(request); + console.log('non streaming response: ', 
JSON.stringify(response)); + } catch (error) { + console.error('Error:', error); + } +} + + + +streamingResponse(); +nonStreamingResponse(); \ No newline at end of file diff --git a/tests/pass_through_tests/test_vertex.test.js b/tests/pass_through_tests/test_vertex.test.js new file mode 100644 index 000000000..dc457c68a --- /dev/null +++ b/tests/pass_through_tests/test_vertex.test.js @@ -0,0 +1,114 @@ +const { VertexAI, RequestOptions } = require('@google-cloud/vertexai'); +const fs = require('fs'); +const path = require('path'); +const os = require('os'); +const { writeFileSync } = require('fs'); + + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + // Modify the URL to use HTTP instead of HTTPS + if (url.startsWith('https://localhost:4000')) { + url = url.replace('https://', 'http://'); + } + console.log('Patched fetch sending request to:', url); + return originalFetch(url, options); +}; + +function loadVertexAiCredentials() { + console.log("loading vertex ai credentials"); + const filepath = path.dirname(__filename); + const vertexKeyPath = path.join(filepath, "vertex_key.json"); + + // Initialize default empty service account data + let serviceAccountKeyData = {}; + + // Try to read existing vertex_key.json + try { + const content = fs.readFileSync(vertexKeyPath, 'utf8'); + if (content && content.trim()) { + serviceAccountKeyData = JSON.parse(content); + } + } catch (error) { + // File doesn't exist or is invalid, continue with empty object + } + + // Update with environment variables + const privateKeyId = process.env.VERTEX_AI_PRIVATE_KEY_ID || ""; + const privateKey = (process.env.VERTEX_AI_PRIVATE_KEY || "").replace(/\\n/g, "\n"); + + serviceAccountKeyData.private_key_id = privateKeyId; + serviceAccountKeyData.private_key = privateKey; + + // Create temporary file + const tempFilePath = path.join(os.tmpdir(), `vertex-credentials-${Date.now()}.json`); + writeFileSync(tempFilePath, JSON.stringify(serviceAccountKeyData, null, 2)); + + // Set environment variable + process.env.GOOGLE_APPLICATION_CREDENTIALS = tempFilePath; +} + +// Run credential loading before tests +beforeAll(() => { + loadVertexAiCredentials(); +}); + + + +describe('Vertex AI Tests', () => { + test('should successfully generate content from Vertex AI', async () => { + const vertexAI = new VertexAI({ + project: 'adroit-crow-413218', + location: 'us-central1', + apiEndpoint: "localhost:4000/vertex-ai" + }); + + const customHeaders = new Headers({ + "x-litellm-api-key": "sk-1234" + }); + + const requestOptions = { + customHeaders: customHeaders + }; + + const generativeModel = vertexAI.getGenerativeModel( + { model: 'gemini-1.0-pro' }, + requestOptions + ); + + const request = { + contents: [{role: 'user', parts: [{text: 'How are you doing today tell me your name?'}]}], + }; + + const streamingResult = await generativeModel.generateContentStream(request); + + // Add some assertions + expect(streamingResult).toBeDefined(); + + for await (const item of streamingResult.stream) { + console.log('stream chunk:', JSON.stringify(item)); + expect(item).toBeDefined(); + } + + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response:', JSON.stringify(aggregatedResponse)); + expect(aggregatedResponse).toBeDefined(); + }); + + + test('should successfully generate non-streaming content from Vertex AI', async () => { + const vertexAI = new 
VertexAI({project: 'adroit-crow-413218', location: 'us-central1', apiEndpoint: "localhost:4000/vertex-ai"}); + const customHeaders = new Headers({"x-litellm-api-key": "sk-1234"}); + const requestOptions = {customHeaders: customHeaders}; + const generativeModel = vertexAI.getGenerativeModel({model: 'gemini-1.0-pro'}, requestOptions); + const request = {contents: [{role: 'user', parts: [{text: 'What is 2+2?'}]}]}; + + const result = await generativeModel.generateContent(request); + expect(result).toBeDefined(); + expect(result.response).toBeDefined(); + console.log('non-streaming response:', JSON.stringify(result.response)); + }); +}); \ No newline at end of file diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic.py b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py similarity index 100% rename from tests/pass_through_unit_tests/test_unit_test_anthropic.py rename to tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py new file mode 100644 index 000000000..bbbc465fc --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -0,0 +1,118 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch, MagicMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import httpx +import pytest +import litellm +from typing import AsyncGenerator +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy.pass_through_endpoints.types import EndpointType +from litellm.proxy.pass_through_endpoints.success_handler import ( + PassThroughEndpointLogging, +) +from litellm.proxy.pass_through_endpoints.streaming_handler import ( + PassThroughStreamingHandler, +) + + +# Helper function to mock async iteration +async def aiter_mock(iterable): + for item in iterable: + yield item + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "endpoint_type,url_route", + [ + ( + EndpointType.VERTEX_AI, + "v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent", + ), + (EndpointType.ANTHROPIC, "/v1/messages"), + ], +) +async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): + """ + Test that the chunk_processor yields raw bytes + + This is CRITICAL for pass throughs streaming with Vertex AI and Anthropic + """ + # Mock inputs + response = AsyncMock(spec=httpx.Response) + raw_chunks = [ + b'{"id": "1", "content": "Hello"}', + b'{"id": "2", "content": "World"}', + b'\n\ndata: {"id": "3"}', # Testing different byte formats + ] + + # Mock aiter_bytes to return an async generator + async def mock_aiter_bytes(): + for chunk in raw_chunks: + yield chunk + + response.aiter_bytes = mock_aiter_bytes + + request_body = {"key": "value"} + litellm_logging_obj = MagicMock() + start_time = datetime.now() + passthrough_success_handler_obj = MagicMock() + + # Capture yielded chunks and perform detailed assertions + received_chunks = [] + async for chunk in PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=request_body, + litellm_logging_obj=litellm_logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + ): + # Assert each chunk is bytes + assert isinstance(chunk, bytes), f"Chunk 
should be bytes, got {type(chunk)}" + # Assert no decoding/encoding occurred (chunk should be exactly as input) + assert ( + chunk in raw_chunks + ), f"Chunk {chunk} was modified during processing. For pass throughs streaming, chunks should be raw bytes" + received_chunks.append(chunk) + + # Assert all chunks were processed + assert len(received_chunks) == len(raw_chunks), "Not all chunks were processed" + + # collected chunks all together + assert b"".join(received_chunks) == b"".join( + raw_chunks + ), "Collected chunks do not match raw chunks" + + +def test_convert_raw_bytes_to_str_lines(): + """ + Test that the _convert_raw_bytes_to_str_lines method correctly converts raw bytes to a list of strings + """ + # Test case 1: Single chunk + raw_bytes = [b'data: {"content": "Hello"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}'] + + # Test case 2: Multiple chunks + raw_bytes = [b'data: {"content": "Hello"}\n', b'data: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] + + # Test case 3: Empty input + raw_bytes = [] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == [] + + # Test case 4: Chunks with empty lines + raw_bytes = [b'data: {"content": "Hello"}\n\n', b'\ndata: {"content": "World"}\n'] + result = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_bytes) + assert result == ['data: {"content": "Hello"}', 'data: {"content": "World"}'] diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py new file mode 100644 index 000000000..a7b668813 --- /dev/null +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -0,0 +1,84 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + +from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + get_litellm_virtual_key, + vertex_proxy_route, +) + + +@pytest.mark.asyncio +async def test_get_litellm_virtual_key(): + """ + Test that the get_litellm_virtual_key function correctly handles the API key authentication + """ + # Test with x-litellm-api-key + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + # Test with Authorization header + mock_request.headers = {"Authorization": "Bearer auth-key-456"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer auth-key-456" + + # Test with both headers (x-litellm-api-key should take precedence) + mock_request.headers = { + "x-litellm-api-key": "test-key-123", + "Authorization": "Bearer auth-key-456", + } + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + +@pytest.mark.asyncio +async def test_vertex_proxy_route_api_key_auth(): + """ + Critical + + This is how Vertex AI JS SDK will Auth to Litellm Proxy + """ + # Mock dependencies + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + 
mock_request.method = "POST" + mock_response = Mock() + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth" + ) as mock_auth: + mock_auth.return_value = {"api_key": "test-key-123"} + + with patch( + "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route" + ) as mock_pass_through: + mock_pass_through.return_value = AsyncMock( + return_value={"status": "success"} + ) + + # Call the function + result = await vertex_proxy_route( + endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent", + request=mock_request, + fastapi_response=mock_response, + ) + + # Verify user_api_key_auth was called with the correct Bearer token + mock_auth.assert_called_once() + call_args = mock_auth.call_args[1] + assert call_args["api_key"] == "Bearer test-key-123" From 97cde31113db2654310afe169922831bf26be65c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 17:35:38 -0800 Subject: [PATCH 25/78] fix tests (#6875) --- .circleci/config.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1d7ed7602..78bdf3d8e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1179,15 +1179,7 @@ jobs: pip install "PyGithub==1.59.1" pip install "google-cloud-aiplatform==1.59.0" pip install anthropic - python -m pip install -r requirements.txt # Run pytest and generate JUnit XML report - - run: - name: Run tests - command: | - pwd - ls - python -m pytest -vv tests/pass_through_unit_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 - no_output_timeout: 120m - run: name: Build Docker image command: docker build -t my-app:latest -f ./docker/Dockerfile.database . From 772b2f9cd2e8a55e0319117d2e5ff2352b9fa384 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:42:08 -0800 Subject: [PATCH 26/78] Bump cross-spawn from 7.0.3 to 7.0.6 in /ui/litellm-dashboard (#6865) Bumps [cross-spawn](https://github.com/moxystudio/node-cross-spawn) from 7.0.3 to 7.0.6. - [Changelog](https://github.com/moxystudio/node-cross-spawn/blob/master/CHANGELOG.md) - [Commits](https://github.com/moxystudio/node-cross-spawn/compare/v7.0.3...v7.0.6) --- updated-dependencies: - dependency-name: cross-spawn dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- ui/litellm-dashboard/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ui/litellm-dashboard/package-lock.json b/ui/litellm-dashboard/package-lock.json index ee1c9c481..c50c173d8 100644 --- a/ui/litellm-dashboard/package-lock.json +++ b/ui/litellm-dashboard/package-lock.json @@ -1852,9 +1852,9 @@ } }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", From d81ae4582717deb6b18b30c14fa34a1b5ce89e80 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 18:47:26 -0800 Subject: [PATCH 27/78] (Perf / latency improvement) improve pass through endpoint latency to ~50ms (before PR was 400ms) (#6874) * use correct location for types * fix types location * perf improvement for pass through endpoints * update lint check * fix import * fix ensure async clients test * fix azure.py health check * fix ollama --- litellm/llms/AzureOpenAI/azure.py | 3 ++- litellm/llms/custom_httpx/http_handler.py | 3 +-- litellm/llms/custom_httpx/types.py | 11 --------- litellm/llms/ollama.py | 6 ++++- litellm/llms/ollama_chat.py | 6 ++++- .../pass_through_endpoints.py | 9 ++++++-- .../secret_managers/aws_secret_manager_v2.py | 2 +- litellm/types/llms/custom_http.py | 20 ++++++++++++++++ .../ensure_async_clients_test.py | 23 +++++++++++++++++++ 9 files changed, 64 insertions(+), 19 deletions(-) delete mode 100644 litellm/llms/custom_httpx/types.py create mode 100644 litellm/types/llms/custom_http.py diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index f6a1790b6..24303ef2f 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -1528,7 +1528,8 @@ class AzureChatCompletion(BaseLLM): prompt: Optional[str] = None, ) -> dict: client_session = ( - litellm.aclient_session or httpx.AsyncClient() + litellm.aclient_session + or get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE).client ) # handle dall-e-2 calls if "gateway.ai.cloudflare.com" in api_base: diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index f1b78ea63..f5c4f694d 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -8,8 +8,7 @@ from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport import litellm from litellm.caching import InMemoryCache - -from .types import httpxSpecialProvider +from litellm.types.llms.custom_http import * if TYPE_CHECKING: from litellm import LlmProviders diff --git a/litellm/llms/custom_httpx/types.py b/litellm/llms/custom_httpx/types.py deleted file mode 100644 index 8e6ad0eda..000000000 --- a/litellm/llms/custom_httpx/types.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import Enum - -import litellm - - -class httpxSpecialProvider(str, Enum): - LoggingCallback = "logging_callback" - GuardrailCallback = "guardrail_callback" - Caching = "caching" - Oauth2Check = "oauth2_check" - SecretManager = "secret_manager" 
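The hunks in this patch swap per-request `httpx.AsyncClient()` construction for the shared client returned by `get_async_httpx_client(...)`. The sketch below shows that reuse pattern in isolation; it is not part of the patch, and `get_shared_async_client` / `forward_request` are illustrative names rather than litellm APIs.

```python
# Minimal sketch of the client-reuse pattern, assuming nothing beyond httpx itself.
# Building httpx.AsyncClient() inside a request handler pays connection-pool and TLS
# setup on every call; caching one client keeps connections warm across requests,
# which is where the pass-through latency win comes from.
from typing import Optional

import httpx

_shared_client: Optional[httpx.AsyncClient] = None


def get_shared_async_client() -> httpx.AsyncClient:
    """Create the async client once, then hand back the cached instance."""
    global _shared_client
    if _shared_client is None:
        _shared_client = httpx.AsyncClient(timeout=600)
    return _shared_client


async def forward_request(url: str, body: dict) -> httpx.Response:
    # No per-request AsyncClient construction here - the pooled client is reused.
    client = get_shared_async_client()
    return await client.post(url, json=body)
```

The real change applies the same idea via `get_async_httpx_client(llm_provider=httpxSpecialProvider.PassThroughEndpoint, params={"timeout": 600})`, as the `pass_through_endpoints.py` hunk below shows.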
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 896b93be5..e9dd2b53f 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -14,6 +14,7 @@ import requests # type: ignore import litellm from litellm import verbose_logger +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.secret_managers.main import get_secret_str from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices @@ -456,7 +457,10 @@ def ollama_completion_stream(url, data, logging_obj): async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): try: - client = httpx.AsyncClient() + _async_http_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.OLLAMA + ) + client = _async_http_client.client async with client.stream( url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout ) as response: diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py index 536f766e0..ce0df139d 100644 --- a/litellm/llms/ollama_chat.py +++ b/litellm/llms/ollama_chat.py @@ -13,6 +13,7 @@ from pydantic import BaseModel import litellm from litellm import verbose_logger +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction from litellm.types.llms.openai import ChatCompletionAssistantToolCall from litellm.types.utils import StreamingChoices @@ -445,7 +446,10 @@ async def ollama_async_streaming( url, api_key, data, model_response, encoding, logging_obj ): try: - client = httpx.AsyncClient() + _async_http_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.OLLAMA + ) + client = _async_http_client.client _request = { "url": f"{url}", "json": data, diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index f60fd0166..0fd174440 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -22,6 +22,7 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator, ) @@ -35,6 +36,7 @@ from litellm.proxy._types import ( ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.custom_http import httpxSpecialProvider from .streaming_handler import PassThroughStreamingHandler from .success_handler import PassThroughEndpointLogging @@ -363,8 +365,11 @@ async def pass_through_request( # noqa: PLR0915 data=_parsed_body, call_type="pass_through_endpoint", ) - - async_client = httpx.AsyncClient(timeout=600) + async_client_obj = get_async_httpx_client( + llm_provider=httpxSpecialProvider.PassThroughEndpoint, + params={"timeout": 600}, + ) + async_client = async_client_obj.client litellm_call_id = str(uuid.uuid4()) diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py index 69add6f23..32653f57d 100644 --- a/litellm/secret_managers/aws_secret_manager_v2.py +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -31,8 +31,8 @@ from 
litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, ) -from litellm.llms.custom_httpx.types import httpxSpecialProvider from litellm.proxy._types import KeyManagementSystem +from litellm.types.llms.custom_http import httpxSpecialProvider class AWSSecretsManagerV2(BaseAWSLLM): diff --git a/litellm/types/llms/custom_http.py b/litellm/types/llms/custom_http.py new file mode 100644 index 000000000..f43daff2a --- /dev/null +++ b/litellm/types/llms/custom_http.py @@ -0,0 +1,20 @@ +from enum import Enum + +import litellm + + +class httpxSpecialProvider(str, Enum): + """ + Httpx Clients can be created for these litellm internal providers + + Example: + - langsmith logging would need a custom async httpx client + - pass through endpoint would need a custom async httpx client + """ + + LoggingCallback = "logging_callback" + GuardrailCallback = "guardrail_callback" + Caching = "caching" + Oauth2Check = "oauth2_check" + SecretManager = "secret_manager" + PassThroughEndpoint = "pass_through_endpoint" diff --git a/tests/code_coverage_tests/ensure_async_clients_test.py b/tests/code_coverage_tests/ensure_async_clients_test.py index a509e5509..0565de9b3 100644 --- a/tests/code_coverage_tests/ensure_async_clients_test.py +++ b/tests/code_coverage_tests/ensure_async_clients_test.py @@ -5,9 +5,19 @@ ALLOWED_FILES = [ # local files "../../litellm/__init__.py", "../../litellm/llms/custom_httpx/http_handler.py", + "../../litellm/router_utils/client_initalization_utils.py", + "../../litellm/llms/custom_httpx/http_handler.py", + "../../litellm/llms/huggingface_restapi.py", + "../../litellm/llms/base.py", + "../../litellm/llms/custom_httpx/httpx_handler.py", # when running on ci/cd "./litellm/__init__.py", "./litellm/llms/custom_httpx/http_handler.py", + "./litellm/router_utils/client_initalization_utils.py", + "./litellm/llms/custom_httpx/http_handler.py", + "./litellm/llms/huggingface_restapi.py", + "./litellm/llms/base.py", + "./litellm/llms/custom_httpx/httpx_handler.py", ] warning_msg = "this is a serious violation that can impact latency. Creating Async clients per request can add +500ms per request" @@ -43,6 +53,19 @@ def check_for_async_http_handler(file_path): raise ValueError( f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}" ) + # Check for attribute calls like httpx.AsyncClient() + elif isinstance(node.func, ast.Attribute): + full_name = "" + current = node.func + while isinstance(current, ast.Attribute): + full_name = "." + current.attr + full_name + current = current.value + if isinstance(current, ast.Name): + full_name = current.id + full_name + if full_name.lower() in [name.lower() for name in target_names]: + raise ValueError( + f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}" + ) return violations From 7e9d8b58f6e9f5c622513f22a26d5952427af8c9 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 23 Nov 2024 15:17:40 +0530 Subject: [PATCH 28/78] LiteLLM Minor Fixes & Improvements (11/23/2024) (#6870) * feat(pass_through_endpoints/): support logging anthropic/gemini pass through calls to langfuse/s3/etc. 
* fix(utils.py): allow disabling end user cost tracking with new param Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small * docs(configs.md): add disable_end_user_cost_tracking reference to docs * feat(key_management_endpoints.py): add support for restricting access to `/key/generate` by team/proxy level role Enables admin to restrict key creation, and assign team admins to handle distributing keys * test(test_key_management.py): add unit testing for personal / team key restriction checks * docs: add docs on restricting key creation * docs(finetuned_models.md): add new guide on calling finetuned models * docs(input.md): cleanup anthropic supported params Closes https://github.com/BerriAI/litellm/issues/6856 * test(test_embedding.py): add test for passing extra headers via embedding * feat(cohere/embed): pass client to async embedding * feat(rerank.py): add `/v1/rerank` if missing for cohere base url Closes https://github.com/BerriAI/litellm/issues/6844 * fix(main.py): pass extra_headers param to openai Fixes https://github.com/BerriAI/litellm/issues/6836 * fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically * fix(handler.py): fix linting error * fix: fix typing * build: add conftest to proxy_admin_ui_tests/ * test: fix test * fix: fix linting errors * test: fix test * fix: fix pass through testing --- docs/my-website/docs/completion/input.md | 2 +- .../docs/guides/finetuned_models.md | 74 ++++++++++++++ docs/my-website/docs/proxy/configs.md | 2 + docs/my-website/docs/proxy/self_serve.md | 8 +- docs/my-website/docs/proxy/virtual_keys.md | 69 +++++++++++++ docs/my-website/sidebars.js | 42 ++++---- litellm/__init__.py | 3 + litellm/integrations/prometheus.py | 10 +- litellm/litellm_core_utils/litellm_logging.py | 88 ++++++----------- litellm/llms/azure_ai/rerank/handler.py | 2 + litellm/llms/cohere/embed/handler.py | 6 ++ litellm/llms/cohere/rerank.py | 37 ++++++- litellm/main.py | 4 + litellm/proxy/_new_secret_config.yaml | 4 +- litellm/proxy/_types.py | 72 +++++++++++--- .../key_management_endpoints.py | 73 ++++++++++++++ .../organization_endpoints.py | 4 +- .../anthropic_passthrough_logging_handler.py | 39 ++++---- .../vertex_passthrough_logging_handler.py | 55 ++++++----- .../streaming_handler.py | 68 ++++++++++--- .../pass_through_endpoints/success_handler.py | 97 +++++++++++-------- litellm/proxy/proxy_server.py | 4 +- litellm/proxy/utils.py | 15 ++- litellm/rerank_api/main.py | 4 +- litellm/types/utils.py | 13 +++ litellm/utils.py | 10 ++ tests/local_testing/test_embedding.py | 31 ++++++ tests/local_testing/test_rerank.py | 34 ++++++- tests/local_testing/test_utils.py | 20 ++++ .../test_unit_tests_init_callbacks.py | 75 ++++++++++++++ .../test_unit_test_anthropic_pass_through.py | 27 +----- .../test_unit_test_streaming.py | 1 + tests/proxy_admin_ui_tests/conftest.py | 54 +++++++++++ .../test_key_management.py | 62 ++++++++++++ .../test_role_based_access.py | 10 +- 35 files changed, 871 insertions(+), 248 deletions(-) create mode 100644 docs/my-website/docs/guides/finetuned_models.md create mode 100644 tests/proxy_admin_ui_tests/conftest.py diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index c563a5bf0..e55c160e0 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -41,7 +41,7 @@ Use 
`litellm.get_supported_openai_params()` for an updated list of params for ea | Provider | temperature | max_completion_tokens | max_tokens | top_p | stream | stream_options | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers | |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| -|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | +|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ | |OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ | |Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ | |Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | diff --git a/docs/my-website/docs/guides/finetuned_models.md b/docs/my-website/docs/guides/finetuned_models.md new file mode 100644 index 000000000..cb0d49b44 --- /dev/null +++ b/docs/my-website/docs/guides/finetuned_models.md @@ -0,0 +1,74 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + + +# Calling Finetuned Models + +## OpenAI + + +| Model Name | Function Call | +|---------------------------|-----------------------------------------------------------------| +| fine tuned `gpt-4-0613` | `response = completion(model="ft:gpt-4-0613", messages=messages)` | +| fine tuned `gpt-4o-2024-05-13` | `response = completion(model="ft:gpt-4o-2024-05-13", messages=messages)` | +| fine tuned `gpt-3.5-turbo-0125` | `response = completion(model="ft:gpt-3.5-turbo-0125", messages=messages)` | +| fine tuned `gpt-3.5-turbo-1106` | `response = completion(model="ft:gpt-3.5-turbo-1106", messages=messages)` | +| fine tuned `gpt-3.5-turbo-0613` | `response = completion(model="ft:gpt-3.5-turbo-0613", messages=messages)` | + + +## Vertex AI + +Fine tuned models on vertex have a numerical model/endpoint id. + + + + +```python +from litellm import completion +import os + +## set ENV variables +os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811" +os.environ["VERTEXAI_LOCATION"] = "us-central1" + +response = completion( + model="vertex_ai/", # e.g. vertex_ai/4965075652664360960 + messages=[{ "content": "Hello, how are you?","role": "user"}], + base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing +) +``` + + + + +1. Add Vertex Credentials to your env + +```bash +!gcloud auth application-default login +``` + +2. Setup config.yaml + +```yaml +- model_name: finetuned-gemini + litellm_params: + model: vertex_ai/ + vertex_project: + vertex_location: + model_info: + base_model: vertex_ai/gemini-1.5-pro # IMPORTANT +``` + +3. Test it! + +```bash +curl --location 'https://0.0.0.0:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: ' \ +--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}' +``` + + + + + diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 3b6b336d6..df22a29e3 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -754,6 +754,8 @@ general_settings: | cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. | | cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) | | cache_params.mode | string | The mode of the cache. 
[Further docs](./caching) | +| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. | +| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) | ### general_settings - Reference diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index e04aa4b44..494d9e60d 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -217,4 +217,10 @@ litellm_settings: max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None. tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None. rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None. -``` \ No newline at end of file + + key_generation_settings: # Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) + team_key_generation: + allowed_team_member_roles: ["admin"] + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] +``` diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 3b9a2a03e..98b06d33b 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -811,6 +811,75 @@ litellm_settings: team_id: "core-infra" ``` +### Restricting Key Generation + +Use this to control who can generate keys. Useful when letting others create keys on the UI. + +```yaml +litellm_settings: + key_generation_settings: + team_key_generation: + allowed_team_member_roles: ["admin"] + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] +``` + +#### Spec + +```python +class TeamUIKeyGenerationConfig(TypedDict): + allowed_team_member_roles: List[str] + + +class PersonalUIKeyGenerationConfig(TypedDict): + allowed_user_roles: List[LitellmUserRoles] + + +class StandardKeyGenerationConfig(TypedDict, total=False): + team_key_generation: TeamUIKeyGenerationConfig + personal_key_generation: PersonalUIKeyGenerationConfig + + +class LitellmUserRoles(str, enum.Enum): + """ + Admin Roles: + PROXY_ADMIN: admin over the platform + PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend + ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization + + Internal User Roles: + INTERNAL_USER: can login, view/create/delete their own keys, view their spend + INTERNAL_USER_VIEW_ONLY: can login, view their own keys, view their own spend + + + Team Roles: + TEAM: used for JWT auth + + + Customer Roles: + CUSTOMER: External users -> these are customers + + """ + + # Admin Roles + PROXY_ADMIN = "proxy_admin" + PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer" + + # Organization admins + ORG_ADMIN = "org_admin" + + # Internal User Roles + INTERNAL_USER = "internal_user" + INTERNAL_USER_VIEW_ONLY = "internal_user_viewer" + + # Team Roles + TEAM = "team" + + # Customer Roles - External users of proxy + CUSTOMER = "customer" +``` + + ## **Next Steps - Set Budgets, Rate Limits per Virtual Key** [Follow this doc to set budgets, rate limiters per virtual key with LiteLLM](users) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index f01402299..f2bb1c5e9 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -199,6 +199,31 @@ const sidebars = { ], }, + { 
+ type: "category", + label: "Guides", + items: [ + "exception_mapping", + "completion/provider_specific_params", + "guides/finetuned_models", + "completion/audio", + "completion/vision", + "completion/json_mode", + "completion/prompt_caching", + "completion/predict_outputs", + "completion/prefix", + "completion/drop_params", + "completion/prompt_formatting", + "completion/stream", + "completion/message_trimming", + "completion/function_call", + "completion/model_alias", + "completion/batching", + "completion/mock_requests", + "completion/reliable_completions", + + ] + }, { type: "category", label: "Supported Endpoints", @@ -214,25 +239,8 @@ const sidebars = { }, items: [ "completion/input", - "completion/provider_specific_params", - "completion/json_mode", - "completion/prompt_caching", - "completion/audio", - "completion/vision", - "completion/predict_outputs", - "completion/prefix", - "completion/drop_params", - "completion/prompt_formatting", "completion/output", "completion/usage", - "exception_mapping", - "completion/stream", - "completion/message_trimming", - "completion/function_call", - "completion/model_alias", - "completion/batching", - "completion/mock_requests", - "completion/reliable_completions", ], }, "embedding/supported_embedding", diff --git a/litellm/__init__.py b/litellm/__init__.py index c978b24ee..65b1b3465 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -24,6 +24,7 @@ from litellm.proxy._types import ( KeyManagementSettings, LiteLLM_UpperboundKeyGenerateParams, ) +from litellm.types.utils import StandardKeyGenerationConfig import httpx import dotenv from enum import Enum @@ -273,6 +274,7 @@ s3_callback_params: Optional[Dict] = None generic_logger_headers: Optional[Dict] = None default_key_generate_params: Optional[Dict] = None upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None +key_generation_settings: Optional[StandardKeyGenerationConfig] = None default_internal_user_params: Optional[Dict] = None default_team_settings: Optional[List] = None max_user_budget: Optional[float] = None @@ -280,6 +282,7 @@ default_max_internal_user_budget: Optional[float] = None max_internal_user_budget: Optional[float] = None internal_user_budget_duration: Optional[str] = None max_end_user_budget: Optional[float] = None +disable_end_user_cost_tracking: Optional[bool] = None #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None #### RELIABILITY #### diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index bb28719a3..1460a1d7f 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.types.integrations.prometheus import * from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking class PrometheusLogger(CustomLogger): @@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger): model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = 
standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] @@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger): # unpack kwargs model = kwargs.get("model", "") - litellm_params = kwargs.get("litellm_params", {}) or {} standard_logging_payload: StandardLoggingPayload = kwargs.get( "standard_logging_object", {} ) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - - end_user_id = proxy_server_request.get("body", {}).get("user", None) + litellm_params = kwargs.get("litellm_params", {}) or {} + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 69d6adca4..298e28974 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -934,19 +934,10 @@ class Logging: status="success", ) ) - if self.dynamic_success_callbacks is not None and isinstance( - self.dynamic_success_callbacks, list - ): - callbacks = self.dynamic_success_callbacks - ## keep the internal functions ## - for callback in litellm.success_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) ## REDACT MESSAGES ## result = redact_message_input_output_from_logging( @@ -1368,8 +1359,11 @@ class Logging: and customLogger is not None ): # custom logger functions print_verbose( - "success callbacks: Running Custom Callback Function" + "success callbacks: Running Custom Callback Function - {}".format( + callback + ) ) + customLogger.log_event( kwargs=self.model_call_details, response_obj=result, @@ -1466,21 +1460,10 @@ class Logging: status="success", ) ) - if self.dynamic_async_success_callbacks is not None and isinstance( - self.dynamic_async_success_callbacks, list - ): - callbacks = self.dynamic_async_success_callbacks - ## keep the internal functions ## - for callback in litellm._async_success_callback: - callback_name = "" - if isinstance(callback, CustomLogger): - callback_name = callback.__class__.__name__ - if callable(callback): - callback_name = callback.__name__ - if "_PROXY_" in callback_name: - callbacks.append(callback) - else: - callbacks = litellm._async_success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_success_callbacks, + global_callbacks=litellm._async_success_callback, + ) result = redact_message_input_output_from_logging( model_call_details=( @@ -1747,21 +1730,10 @@ class Logging: start_time=start_time, end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_failure_callbacks is not None and isinstance( - self.dynamic_failure_callbacks, list - ): - callbacks = self.dynamic_failure_callbacks - ## keep the internal functions ## - for callback in litellm.failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.failure_callback + 
callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_failure_callbacks, + global_callbacks=litellm.failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created @@ -1944,21 +1916,10 @@ class Logging: end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_async_failure_callbacks is not None and isinstance( - self.dynamic_async_failure_callbacks, list - ): - callbacks = self.dynamic_async_failure_callbacks - ## keep the internal functions ## - for callback in litellm._async_failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm._async_failure_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_failure_callbacks, + global_callbacks=litellm._async_failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created for callback in callbacks: @@ -2359,6 +2320,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_mlflow_logger) return _mlflow_logger # type: ignore + def get_custom_logger_compatible_class( logging_integration: litellm._custom_logger_compatible_callbacks_literal, ) -> Optional[CustomLogger]: @@ -2949,3 +2911,11 @@ def modify_integration(integration_name, integration_params): if integration_name == "supabase": if "table_name" in integration_params: Supabase.supabase_table_name = integration_params["table_name"] + + +def get_combined_callback_list( + dynamic_success_callbacks: Optional[List], global_callbacks: List +) -> List: + if dynamic_success_callbacks is None: + return global_callbacks + return list(set(dynamic_success_callbacks + global_callbacks)) diff --git a/litellm/llms/azure_ai/rerank/handler.py b/litellm/llms/azure_ai/rerank/handler.py index a67c893f2..60edfd296 100644 --- a/litellm/llms/azure_ai/rerank/handler.py +++ b/litellm/llms/azure_ai/rerank/handler.py @@ -4,6 +4,7 @@ import httpx from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.cohere.rerank import CohereRerank +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.rerank import RerankResponse @@ -73,6 +74,7 @@ class AzureAIRerank(CohereRerank): return_documents: Optional[bool] = True, max_chunks_per_doc: Optional[int] = None, _is_async: Optional[bool] = False, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: if headers is None: diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index 5b224c375..afeba10b5 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -74,6 +74,7 @@ async def async_embedding( }, ) ## COMPLETION CALL + if client is None: client = get_async_httpx_client( llm_provider=litellm.LlmProviders.COHERE, @@ -151,6 +152,11 @@ def embedding( api_key=api_key, headers=headers, encoding=encoding, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), ) ## LOGGING diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py index 022ffc6f9..8de2dfbb4 100644 --- a/litellm/llms/cohere/rerank.py +++ b/litellm/llms/cohere/rerank.py @@ -6,10 +6,14 @@ LiteLLM supports the re rank API format, no paramter transformation occurs from typing import Any, Dict, List, Optional, Union +import 
httpx + import litellm from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, _get_httpx_client, get_async_httpx_client, ) @@ -34,6 +38,23 @@ class CohereRerank(BaseLLM): # Merge other headers, overriding any default ones except Authorization return {**default_headers, **headers} + def ensure_rerank_endpoint(self, api_base: str) -> str: + """ + Ensures the `/v1/rerank` endpoint is appended to the given `api_base`. + If `/v1/rerank` is already present, the original URL is returned. + + :param api_base: The base API URL. + :return: A URL with `/v1/rerank` appended if missing. + """ + # Parse the base URL to ensure proper structure + url = httpx.URL(api_base) + + # Check if the URL already ends with `/v1/rerank` + if not url.path.endswith("/v1/rerank"): + url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank") + + return str(url) + def rerank( self, model: str, @@ -48,9 +69,10 @@ class CohereRerank(BaseLLM): return_documents: Optional[bool] = True, max_chunks_per_doc: Optional[int] = None, _is_async: Optional[bool] = False, # New parameter + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: headers = self.validate_environment(api_key=api_key, headers=headers) - + api_base = self.ensure_rerank_endpoint(api_base) request_data = RerankRequest( model=model, query=query, @@ -76,9 +98,13 @@ class CohereRerank(BaseLLM): if _is_async: return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method - client = _get_httpx_client() + if client is not None and isinstance(client, HTTPHandler): + client = client + else: + client = _get_httpx_client() + response = client.post( - api_base, + url=api_base, headers=headers, json=request_data_dict, ) @@ -100,10 +126,13 @@ class CohereRerank(BaseLLM): api_key: str, api_base: str, headers: dict, + client: Optional[AsyncHTTPHandler] = None, ) -> RerankResponse: request_data_dict = request_data.dict(exclude_none=True) - client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE) + client = client or get_async_httpx_client( + llm_provider=litellm.LlmProviders.COHERE + ) response = await client.post( api_base, diff --git a/litellm/main.py b/litellm/main.py index 5d433eb36..5095ce518 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3440,6 +3440,10 @@ def embedding( # noqa: PLR0915 or litellm.openai_key or get_secret_str("OPENAI_API_KEY") ) + + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + api_type = "openai" api_version = None diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ce9bd1d2f..7baf2224c 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -16,7 +16,7 @@ model_list: model: openai/fake api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/ - + router_settings: model_group_alias: "gpt-4-turbo": # Aliased model name @@ -35,4 +35,4 @@ litellm_settings: failure_callback: ["langfuse"] langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2 langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2 - langfuse_host: https://us.cloud.langfuse.com \ No newline at end of file + langfuse_host: https://us.cloud.langfuse.com diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8b8dbf2e5..74e82b0ea 
100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2,6 +2,7 @@ import enum import json import os import sys +import traceback import uuid from dataclasses import fields from datetime import datetime @@ -12,7 +13,15 @@ from typing_extensions import Annotated, TypedDict from litellm.types.integrations.slack_alerting import AlertType from litellm.types.router import RouterErrors, UpdateRouterConfig -from litellm.types.utils import ProviderField, StandardCallbackDynamicParams +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + ProviderField, + StandardCallbackDynamicParams, + StandardPassThroughResponseObject, + TextCompletionResponse, +) if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -882,15 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase): user_ids: List[str] -class Member(LiteLLMBase): - role: Literal[ - LitellmUserRoles.ORG_ADMIN, - LitellmUserRoles.INTERNAL_USER, - LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - # older Member roles - "admin", - "user", - ] +class MemberBase(LiteLLMBase): user_id: Optional[str] = None user_email: Optional[str] = None @@ -904,6 +905,21 @@ class Member(LiteLLMBase): return values +class Member(MemberBase): + role: Literal[ + "admin", + "user", + ] + + +class OrgMember(MemberBase): + role: Literal[ + LitellmUserRoles.ORG_ADMIN, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + ] + + class TeamBase(LiteLLMBase): team_alias: Optional[str] = None team_id: Optional[str] = None @@ -1966,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase): # Replace member_data with the single Member object data["member"] = member # Call the superclass __init__ method to initialize the object + traceback.print_stack() + super().__init__(**data) + + +class OrgMemberAddRequest(LiteLLMBase): + member: Union[List[OrgMember], OrgMember] + + def __init__(self, **data): + member_data = data.get("member") + if isinstance(member_data, list): + # If member is a list of dictionaries, convert each dictionary to a Member object + members = [OrgMember(**item) for item in member_data] + # Replace member_data with the list of Member objects + data["member"] = members + elif isinstance(member_data, dict): + # If member is a dictionary, convert it to a single Member object + member = OrgMember(**member_data) + # Replace member_data with the single Member object + data["member"] = member + # Call the superclass __init__ method to initialize the object super().__init__(**data) @@ -2017,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse): # Organization Member Requests -class OrganizationMemberAddRequest(MemberAddRequest): +class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str max_budget_in_organization: Optional[float] = ( None # Users max budget within the organization @@ -2133,3 +2169,17 @@ class UserManagementEndpointParamDocStringEnums(str, enum.Enum): spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used.""" team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None.""" duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. 
Default is None.""" + + +PassThroughEndpointLoggingResultValues = Union[ + ModelResponse, + TextCompletionResponse, + ImageResponse, + EmbeddingResponse, + StandardPassThroughResponseObject, +] + + +class PassThroughEndpointLoggingTypedDict(TypedDict): + result: Optional[PassThroughEndpointLoggingResultValues] + kwargs: dict diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index e4493a28c..ab13616d5 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -40,6 +40,77 @@ from litellm.proxy.utils import ( ) from litellm.secret_managers.main import get_secret + +def _is_team_key(data: GenerateKeyRequest): + return data.team_id is not None + + +def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth): + if ( + litellm.key_generation_settings is None + or litellm.key_generation_settings.get("team_key_generation") is None + ): + return True + + if user_api_key_dict.team_member is None: + raise HTTPException( + status_code=400, + detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}", + ) + + team_member_role = user_api_key_dict.team_member.role + if ( + team_member_role + not in litellm.key_generation_settings["team_key_generation"][ # type: ignore + "allowed_team_member_roles" + ] + ): + raise HTTPException( + status_code=400, + detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore + ) + return True + + +def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth): + + if ( + litellm.key_generation_settings is None + or litellm.key_generation_settings.get("personal_key_generation") is None + ): + return True + + if ( + user_api_key_dict.user_role + not in litellm.key_generation_settings["personal_key_generation"][ # type: ignore + "allowed_user_roles" + ] + ): + raise HTTPException( + status_code=400, + detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. 
Your role={user_api_key_dict.user_role}", # type: ignore + ) + return True + + +def key_generation_check( + user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest +) -> bool: + """ + Check if admin has restricted key creation to certain roles for teams or individuals + """ + if litellm.key_generation_settings is None: + return True + + ## check if key is for team or individual + is_team_key = _is_team_key(data=data) + + if is_team_key: + return _team_key_generation_check(user_api_key_dict) + else: + return _personal_key_generation_check(user_api_key_dict=user_api_key_dict) + + router = APIRouter() @@ -131,6 +202,8 @@ async def generate_key_fn( # noqa: PLR0915 raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=message ) + elif litellm.key_generation_settings is not None: + key_generation_check(user_api_key_dict=user_api_key_dict, data=data) # check if user set default key/generate params on config.yaml if litellm.default_key_generate_params is not None: for elem in data: diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index 81d135097..363384375 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -352,7 +352,7 @@ async def organization_member_add( }, ) - members: List[Member] + members: List[OrgMember] if isinstance(data.member, List): members = data.member else: @@ -397,7 +397,7 @@ async def organization_member_add( async def add_member_to_organization( - member: Member, + member: OrgMember, organization_id: str, prisma_client: PrismaClient, ) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]: diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index ad5a98258..d155174a7 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat.handler import ( ModelResponseIterator as AnthropicModelResponseIterator, ) from litellm.llms.anthropic.chat.transformation import AnthropicConfig +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict if TYPE_CHECKING: from ..success_handler import PassThroughEndpointLogging @@ -26,7 +27,7 @@ else: class AnthropicPassthroughLoggingHandler: @staticmethod - async def anthropic_passthrough_handler( + def anthropic_passthrough_handler( httpx_response: httpx.Response, response_body: dict, logging_obj: LiteLLMLoggingObj, @@ -36,7 +37,7 @@ class AnthropicPassthroughLoggingHandler: end_time: datetime, cache_hit: bool, **kwargs, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled """ @@ -67,15 +68,10 @@ class AnthropicPassthroughLoggingHandler: logging_obj=logging_obj, ) - await logging_obj.async_success_handler( - result=litellm_model_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) - - pass + return { + "result": litellm_model_response, + "kwargs": kwargs, + } @staticmethod def _create_anthropic_response_logging_payload( @@ -123,7 +119,7 @@ class AnthropicPassthroughLoggingHandler: return kwargs @staticmethod - 
async def _handle_logging_anthropic_collected_chunks( + def _handle_logging_anthropic_collected_chunks( litellm_logging_obj: LiteLLMLoggingObj, passthrough_success_handler_obj: PassThroughEndpointLogging, url_route: str, @@ -132,7 +128,7 @@ class AnthropicPassthroughLoggingHandler: start_time: datetime, all_chunks: List[str], end_time: datetime, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks @@ -152,7 +148,10 @@ class AnthropicPassthroughLoggingHandler: verbose_proxy_logger.error( "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..." ) - return + return { + "result": None, + "kwargs": {}, + } kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( litellm_model_response=complete_streaming_response, model=model, @@ -161,13 +160,11 @@ class AnthropicPassthroughLoggingHandler: end_time=end_time, logging_obj=litellm_logging_obj, ) - await litellm_logging_obj.async_success_handler( - result=complete_streaming_response, - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } @staticmethod def _build_complete_streaming_response( diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index 275a0a119..2773979ad 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -14,6 +14,7 @@ from litellm.litellm_core_utils.litellm_logging import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator as VertexModelResponseIterator, ) +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict if TYPE_CHECKING: from ..success_handler import PassThroughEndpointLogging @@ -25,7 +26,7 @@ else: class VertexPassthroughLoggingHandler: @staticmethod - async def vertex_passthrough_handler( + def vertex_passthrough_handler( httpx_response: httpx.Response, logging_obj: LiteLLMLoggingObj, url_route: str, @@ -34,7 +35,7 @@ class VertexPassthroughLoggingHandler: end_time: datetime, cache_hit: bool, **kwargs, - ): + ) -> PassThroughEndpointLoggingTypedDict: if "generateContent" in url_route: model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route) @@ -65,13 +66,11 @@ class VertexPassthroughLoggingHandler: logging_obj=logging_obj, ) - await logging_obj.async_success_handler( - result=litellm_model_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + elif "predict" in url_route: from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import ( VertexImageGeneration, @@ -112,16 +111,18 @@ class VertexPassthroughLoggingHandler: logging_obj.model = model logging_obj.model_call_details["model"] = logging_obj.model - await logging_obj.async_success_handler( - result=litellm_prediction_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) + return { + "result": litellm_prediction_response, + "kwargs": kwargs, + } + else: + return { + "result": None, + "kwargs": kwargs, + } 
@staticmethod - async def _handle_logging_vertex_collected_chunks( + def _handle_logging_vertex_collected_chunks( litellm_logging_obj: LiteLLMLoggingObj, passthrough_success_handler_obj: PassThroughEndpointLogging, url_route: str, @@ -130,7 +131,7 @@ class VertexPassthroughLoggingHandler: start_time: datetime, all_chunks: List[str], end_time: datetime, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks @@ -152,7 +153,11 @@ class VertexPassthroughLoggingHandler: verbose_proxy_logger.error( "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." ) - return + return { + "result": None, + "kwargs": kwargs, + } + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( litellm_model_response=complete_streaming_response, model=model, @@ -161,13 +166,11 @@ class VertexPassthroughLoggingHandler: end_time=end_time, logging_obj=litellm_logging_obj, ) - await litellm_logging_obj.async_success_handler( - result=complete_streaming_response, - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } @staticmethod def _build_complete_streaming_response( diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index 522319aaa..dc6aae3af 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -1,5 +1,6 @@ import asyncio import json +import threading from datetime import datetime from enum import Enum from typing import AsyncIterable, Dict, List, Optional, Union @@ -15,7 +16,12 @@ from litellm.llms.anthropic.chat.handler import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator as VertexAIIterator, ) -from litellm.types.utils import GenericStreamingChunk +from litellm.proxy._types import PassThroughEndpointLoggingResultValues +from litellm.types.utils import ( + GenericStreamingChunk, + ModelResponse, + StandardPassThroughResponseObject, +) from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( AnthropicPassthroughLoggingHandler, @@ -87,8 +93,12 @@ class PassThroughStreamingHandler: all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( raw_bytes ) + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None + kwargs: dict = {} if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( litellm_logging_obj=litellm_logging_obj, passthrough_success_handler_obj=passthrough_success_handler_obj, url_route=url_route, @@ -98,20 +108,48 @@ class PassThroughStreamingHandler: all_chunks=all_chunks, end_time=end_time, ) + standard_logging_response_object = anthropic_passthrough_logging_handler_result[ + "result" + ] + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - 
request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + standard_logging_response_object = vertex_passthrough_logging_handler_result[ + "result" + ] + kwargs = vertex_passthrough_logging_handler_result["kwargs"] + + if standard_logging_response_object is None: + standard_logging_response_object = StandardPassThroughResponseObject( + response=f"cannot parse chunks to standard response object. Chunks={all_chunks}" + ) + threading.Thread( + target=litellm_logging_obj.success_handler, + args=( + standard_logging_response_object, + start_time, + end_time, + False, + ), + ).start() + await litellm_logging_obj.async_success_handler( + result=standard_logging_response_object, + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) @staticmethod def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: @@ -130,4 +168,4 @@ class PassThroughStreamingHandler: # Split by newlines and filter out empty lines lines = [line.strip() for line in combined_str.split("\n") if line.strip()] - return lines + return lines \ No newline at end of file diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index e22a37052..c9c7707f0 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -15,6 +15,7 @@ from litellm.litellm_core_utils.litellm_logging import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) +from litellm.proxy._types import PassThroughEndpointLoggingResultValues from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.types.utils import StandardPassThroughResponseObject @@ -49,53 +50,69 @@ class PassThroughEndpointLogging: cache_hit: bool, **kwargs, ): + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None if self.is_vertex_route(url_route): - await VertexPassthroughLoggingHandler.vertex_passthrough_handler( - httpx_response=httpx_response, - logging_obj=logging_obj, - url_route=url_route, - result=result, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler.vertex_passthrough_handler( + httpx_response=httpx_response, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) ) + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) + kwargs = vertex_passthrough_logging_handler_result["kwargs"] elif self.is_anthropic_route(url_route): - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( - httpx_response=httpx_response, - response_body=response_body or {}, - logging_obj=logging_obj, - url_route=url_route, - result=result, - start_time=start_time, - 
end_time=end_time, - cache_hit=cache_hit, - **kwargs, + anthropic_passthrough_logging_handler_result = ( + AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + httpx_response=httpx_response, + response_body=response_body or {}, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) ) - else: + + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + if standard_logging_response_object is None: standard_logging_response_object = StandardPassThroughResponseObject( response=httpx_response.text ) - threading.Thread( - target=logging_obj.success_handler, - args=( - standard_logging_response_object, - start_time, - end_time, - cache_hit, - ), - ).start() - await logging_obj.async_success_handler( - result=( - json.dumps(result) - if isinstance(result, dict) - else standard_logging_response_object - ), - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + threading.Thread( + target=logging_obj.success_handler, + args=( + standard_logging_response_object, + start_time, + end_time, + cache_hit, + ), + ).start() + await logging_obj.async_success_handler( + result=( + json.dumps(result) + if isinstance(result, dict) + else standard_logging_response_object + ), + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) def is_vertex_route(self, url_route: str): for route in self.TRACKED_VERTEX_ROUTES: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9d7c120a7..70bf5b523 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import ( from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import RouterGeneralSettings from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking try: from litellm._version import version @@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback( ) parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) litellm_params = kwargs.get("litellm_params", {}) or {} - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) user_id = metadata.get("user_api_key_user_id", None) team_id = metadata.get("user_api_key_team_id", None) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 74bf398e7..0f7d6c3e0 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -337,14 +337,14 @@ class ProxyLogging: alert_to_webhook_url=self.alert_to_webhook_url, ) - if ( - self.alerting is not None - and "slack" in self.alerting - and "daily_reports" in self.alert_types - ): + if self.alerting is not None and "slack" in self.alerting: # NOTE: ENSURE we only add callbacks when alerting is on # We should NOT add callbacks when alerting is off - litellm.callbacks.append(self.slack_alerting_instance) # type: ignore + if "daily_reports" in self.alert_types: + litellm.callbacks.append(self.slack_alerting_instance) # type: ignore + litellm.success_callback.append( + self.slack_alerting_instance.response_taking_too_long_callback + ) if redis_cache is not None: 
self.internal_usage_cache.dual_cache.redis_cache = redis_cache @@ -354,9 +354,6 @@ class ProxyLogging: litellm.callbacks.append(self.max_budget_limiter) # type: ignore litellm.callbacks.append(self.cache_control_check) # type: ignore litellm.callbacks.append(self.service_logging_obj) # type: ignore - litellm.success_callback.append( - self.slack_alerting_instance.response_taking_too_long_callback - ) for callback in litellm.callbacks: if isinstance(callback, str): callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 9cc8a8c1d..7e6dc7503 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -91,6 +91,7 @@ def rerank( model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) user = kwargs.get("user", None) + client = kwargs.get("client", None) try: _is_async = kwargs.pop("arerank", False) is True optional_params = GenericLiteLLMParams(**kwargs) @@ -150,7 +151,7 @@ def rerank( or optional_params.api_base or litellm.api_base or get_secret("COHERE_API_BASE") # type: ignore - or "https://api.cohere.com/v1/rerank" + or "https://api.cohere.com" ) if api_base is None: @@ -173,6 +174,7 @@ def rerank( _is_async=_is_async, headers=headers, litellm_logging_obj=litellm_logging_obj, + client=client, ) elif _custom_llm_provider == "azure_ai": api_base = ( diff --git a/litellm/types/utils.py b/litellm/types/utils.py index d02129681..334894320 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1602,3 +1602,16 @@ class StandardCallbackDynamicParams(TypedDict, total=False): langsmith_api_key: Optional[str] langsmith_project: Optional[str] langsmith_base_url: Optional[str] + + +class TeamUIKeyGenerationConfig(TypedDict): + allowed_team_member_roles: List[str] + + +class PersonalUIKeyGenerationConfig(TypedDict): + allowed_user_roles: List[str] + + +class StandardKeyGenerationConfig(TypedDict, total=False): + team_key_generation: TeamUIKeyGenerationConfig + personal_key_generation: PersonalUIKeyGenerationConfig diff --git a/litellm/utils.py b/litellm/utils.py index 003971142..262af3418 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6170,3 +6170,13 @@ class ProviderConfigManager: return litellm.GroqChatConfig() return OpenAIGPTConfig() + + +def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]: + """ + Used for enforcing `disable_end_user_cost_tracking` param. 
+ """ + proxy_server_request = litellm_params.get("proxy_server_request") or {} + if litellm.disable_end_user_cost_tracking: + return None + return proxy_server_request.get("body", {}).get("user", None) diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index d7988e690..096dfc419 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -1080,3 +1080,34 @@ def test_cohere_img_embeddings(input, input_type): assert response.usage.prompt_tokens_details.image_tokens > 0 else: assert response.usage.prompt_tokens_details.text_tokens > 0 + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_embedding_with_extra_headers(sync_mode): + + input = ["hello world"] + from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler + + if sync_mode: + client = HTTPHandler() + else: + client = AsyncHTTPHandler() + + data = { + "model": "cohere/embed-english-v3.0", + "input": input, + "extra_headers": {"my-test-param": "hello-world"}, + "client": client, + } + with patch.object(client, "post") as mock_post: + try: + if sync_mode: + embedding(**data) + else: + await litellm.aembedding(**data) + except Exception as e: + print(e) + + mock_post.assert_called_once() + assert "my-test-param" in mock_post.call_args.kwargs["headers"] diff --git a/tests/local_testing/test_rerank.py b/tests/local_testing/test_rerank.py index c5ed1efe5..5fca6f135 100644 --- a/tests/local_testing/test_rerank.py +++ b/tests/local_testing/test_rerank.py @@ -215,7 +215,10 @@ async def test_rerank_custom_api_base(): args_to_api = kwargs["json"] print("Arguments passed to API=", args_to_api) print("url = ", _url) - assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/" + assert ( + _url[0] + == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank" + ) assert args_to_api == expected_payload assert response.id is not None assert response.results is not None @@ -258,3 +261,32 @@ async def test_rerank_custom_callbacks(): assert custom_logger.kwargs.get("response_cost") > 0.0 assert custom_logger.response_obj is not None assert custom_logger.response_obj.results is not None + + +def test_complete_base_url_cohere(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + litellm.api_base = "http://localhost:4000" + litellm.set_verbose = True + + text = "Hello there!" 
+ list_texts = ["Hello there!", "How are you?", "How do you do?"] + + rerank_model = "rerank-multilingual-v3.0" + + with patch.object(client, "post") as mock_post: + try: + litellm.rerank( + model=rerank_model, + query=text, + documents=list_texts, + custom_llm_provider="cohere", + client=client, + ) + except Exception as e: + print(e) + + print("mock_post.call_args", mock_post.call_args) + mock_post.assert_called_once() + assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"] diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 52946ca30..cf1db27e8 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1012,3 +1012,23 @@ def test_models_by_provider(): for provider in providers: assert provider in models_by_provider.keys() + + +@pytest.mark.parametrize( + "litellm_params, disable_end_user_cost_tracking, expected_end_user_id", + [ + ({}, False, None), + ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"), + ({"proxy_server_request": {"body": {"user": "123"}}}, True, None), + ], +) +def test_get_end_user_id_for_cost_tracking( + litellm_params, disable_end_user_cost_tracking, expected_end_user_id +): + from litellm.utils import get_end_user_id_for_cost_tracking + + litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking + assert ( + get_end_user_id_for_cost_tracking(litellm_params=litellm_params) + == expected_end_user_id + ) diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index 38883fa38..15c2118d8 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -216,3 +216,78 @@ async def test_init_custom_logger_compatible_class_as_callback(): await use_callback_in_llm_call(callback, used_in="success_callback") reset_env_vars() + + +def test_dynamic_logging_global_callback(): + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.integrations.custom_logger import CustomLogger + from litellm.types.utils import ModelResponse, Choices, Message, Usage + + cl = CustomLogger() + + litellm_logging = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + kwargs={ + "langfuse_public_key": "my-mock-public-key", + "langfuse_secret_key": "my-mock-secret-key", + }, + dynamic_success_callbacks=["langfuse"], + ) + + with patch.object(cl, "log_success_event") as mock_log_success_event: + cl.log_success_event = mock_log_success_event + litellm.success_callback = [cl] + + try: + litellm_logging.success_handler( + result=ModelResponse( + id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70", + created=1732306261, + model="claude-3-opus-20240229", + object="chat.completion", + system_fingerprint=None, + choices=[ + Choices( + finish_reason="stop", + index=0, + message=Message( + content="hello", + role="assistant", + tool_calls=None, + function_call=None, + ), + ) + ], + usage=Usage( + completion_tokens=20, + prompt_tokens=10, + total_tokens=30, + completion_tokens_details=None, + prompt_tokens_details=None, + ), + ), + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + ) + except Exception as e: + print(f"Error: {e}") + + mock_log_success_event.assert_called_once() + + +def 
test_get_combined_callback_list(): + from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list + + assert "langfuse" in get_combined_callback_list( + dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] + ) + assert "lago" in get_combined_callback_list( + dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] + ) diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py index afb77f718..ecd289005 100644 --- a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py @@ -73,7 +73,7 @@ async def test_anthropic_passthrough_handler( start_time = datetime.now() end_time = datetime.now() - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + result = AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( httpx_response=mock_httpx_response, response_body=mock_response, logging_obj=mock_logging_obj, @@ -84,30 +84,7 @@ async def test_anthropic_passthrough_handler( cache_hit=False, ) - # Assert that async_success_handler was called - assert mock_logging_obj.async_success_handler.called - - call_args = mock_logging_obj.async_success_handler.call_args - call_kwargs = call_args.kwargs - print("call_kwargs", call_kwargs) - - # Assert required fields are present in call_kwargs - assert "result" in call_kwargs - assert "start_time" in call_kwargs - assert "end_time" in call_kwargs - assert "cache_hit" in call_kwargs - assert "response_cost" in call_kwargs - assert "model" in call_kwargs - assert "standard_logging_object" in call_kwargs - - # Assert specific values and types - assert isinstance(call_kwargs["result"], litellm.ModelResponse) - assert isinstance(call_kwargs["start_time"], datetime) - assert isinstance(call_kwargs["end_time"], datetime) - assert isinstance(call_kwargs["cache_hit"], bool) - assert isinstance(call_kwargs["response_cost"], float) - assert call_kwargs["model"] == "claude-3-opus-20240229" - assert isinstance(call_kwargs["standard_logging_object"], dict) + assert isinstance(result["result"], litellm.ModelResponse) def test_create_anthropic_response_logging_payload(mock_logging_obj): diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py index bbbc465fc..61b71b56d 100644 --- a/tests/pass_through_unit_tests/test_unit_test_streaming.py +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -64,6 +64,7 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): litellm_logging_obj = MagicMock() start_time = datetime.now() passthrough_success_handler_obj = MagicMock() + litellm_logging_obj.async_success_handler = AsyncMock() # Capture yielded chunks and perform detailed assertions received_chunks = [] diff --git a/tests/proxy_admin_ui_tests/conftest.py b/tests/proxy_admin_ui_tests/conftest.py new file mode 100644 index 000000000..eca0bc431 --- /dev/null +++ b/tests/proxy_admin_ui_tests/conftest.py @@ -0,0 +1,54 @@ +# conftest.py + +import importlib +import os +import sys + +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm + + +@pytest.fixture(scope="function", autouse=True) +def setup_and_teardown(): + """ + This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. 
+ """ + curr_dir = os.getcwd() # Get the current working directory + sys.path.insert( + 0, os.path.abspath("../..") + ) # Adds the project directory to the system path + + import litellm + from litellm import Router + + importlib.reload(litellm) + import asyncio + + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) + print(litellm) + # from litellm import Router, completion, aembedding, acompletion, embedding + yield + + # Teardown code (executes after the yield point) + loop.close() # Close the loop created earlier + asyncio.set_event_loop(None) # Remove the reference to the loop + + +def pytest_collection_modifyitems(config, items): + # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests + custom_logger_tests = [ + item for item in items if "custom_logger" in item.parent.name + ] + other_tests = [item for item in items if "custom_logger" not in item.parent.name] + + # Sort tests based on their names + custom_logger_tests.sort(key=lambda x: x.name) + other_tests.sort(key=lambda x: x.name) + + # Reorder the items list + items[:] = custom_logger_tests + other_tests diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index b039a101b..81d9fb676 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -542,3 +542,65 @@ async def test_list_teams(prisma_client): # Clean up await prisma_client.delete_data(team_id_list=[team_id], table_name="team") + + +def test_is_team_key(): + from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key + + assert _is_team_key(GenerateKeyRequest(team_id="test_team_id")) + assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id")) + + +def test_team_key_generation_check(): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _team_key_generation_check, + ) + from fastapi import HTTPException + + litellm.key_generation_settings = { + "team_key_generation": {"allowed_team_member_roles": ["admin"]} + } + + assert _team_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + team_member=Member(role="admin", user_id="test_user_id"), + ) + ) + + with pytest.raises(HTTPException): + _team_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + user_id="test_user_id", + team_member=Member(role="user", user_id="test_user_id"), + ) + ) + + +def test_personal_key_generation_check(): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _personal_key_generation_check, + ) + from fastapi import HTTPException + + litellm.key_generation_settings = { + "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]} + } + + assert _personal_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin" + ) + ) + + with pytest.raises(HTTPException): + _personal_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + user_id="admin", + ) + ) diff --git a/tests/proxy_admin_ui_tests/test_role_based_access.py b/tests/proxy_admin_ui_tests/test_role_based_access.py index 609a3598d..ff73143bf 100644 --- a/tests/proxy_admin_ui_tests/test_role_based_access.py +++ b/tests/proxy_admin_ui_tests/test_role_based_access.py @@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role): response = 
await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=user_role, user_id=created_user_id), + member=OrgMember(role=user_role, user_id=created_user_id), ), http_request=None, ) @@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) @@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) @@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member( + member=OrgMember( role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org ), ), @@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org1_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) From a8b4e1cc0393ee2ad490f091e074c518052ac935 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 08:34:55 -0800 Subject: [PATCH 29/78] fix playwright e2e ui test --- tests/proxy_admin_ui_tests/playwright.config.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/proxy_admin_ui_tests/playwright.config.ts b/tests/proxy_admin_ui_tests/playwright.config.ts index c77897a02..ba15d5458 100644 --- a/tests/proxy_admin_ui_tests/playwright.config.ts +++ b/tests/proxy_admin_ui_tests/playwright.config.ts @@ -13,6 +13,7 @@ import { defineConfig, devices } from '@playwright/test'; */ export default defineConfig({ testDir: './e2e_ui_tests', + testIgnore: '**/tests/pass_through_tests/**', /* Run tests in files in parallel */ fullyParallel: true, /* Fail the build on CI if you accidentally left test.only in the source code. 
*/ From fb5f4584486edb3890a92bb9ef0f0d967c9ccf2e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 08:39:11 -0800 Subject: [PATCH 30/78] fix e2e ui testing deps --- .circleci/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 78bdf3d8e..c9a43b4b7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1373,6 +1373,7 @@ jobs: name: Install Dependencies command: | npm install -D @playwright/test + npm install @google-cloud/vertexai pip install "pytest==7.3.1" pip install "pytest-retry==1.6.3" pip install "pytest-asyncio==0.21.1" From f3ffa675536b57951451b3d746358904643cd031 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 08:45:14 -0800 Subject: [PATCH 31/78] fix e2e ui testing --- tests/proxy_admin_ui_tests/playwright.config.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/proxy_admin_ui_tests/playwright.config.ts b/tests/proxy_admin_ui_tests/playwright.config.ts index ba15d5458..3be77a319 100644 --- a/tests/proxy_admin_ui_tests/playwright.config.ts +++ b/tests/proxy_admin_ui_tests/playwright.config.ts @@ -13,7 +13,8 @@ import { defineConfig, devices } from '@playwright/test'; */ export default defineConfig({ testDir: './e2e_ui_tests', - testIgnore: '**/tests/pass_through_tests/**', + testIgnore: ['**/tests/pass_through_tests/**', '../pass_through_tests/**/*'], + testMatch: '**/*.spec.ts', // Only run files ending in .spec.ts /* Run tests in files in parallel */ fullyParallel: true, /* Fail the build on CI if you accidentally left test.only in the source code. */ From 6b6353d4e75dd41c44de50b577cb5082bc81bccf Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 08:50:10 -0800 Subject: [PATCH 32/78] fix e2e ui testing, only run e2e ui testing in playwright --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c9a43b4b7..d33f62cf3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1435,7 +1435,7 @@ jobs: - run: name: Run Playwright Tests command: | - npx playwright test --reporter=html --output=test-results + npx playwright test e2e_ui_tests/ --reporter=html --output=test-results no_output_timeout: 120m - store_test_results: path: test-results From 424b8b0231e3ed0f42790a05a216c63dcdc1afaa Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 23 Nov 2024 22:37:16 +0530 Subject: [PATCH 33/78] Litellm dev 11 23 2024 (#6881) * build(ui/create_key_button.tsx): support adding tags for cost tracking/routing when making key * LiteLLM Minor Fixes & Improvements (11/23/2024) (#6870) * feat(pass_through_endpoints/): support logging anthropic/gemini pass through calls to langfuse/s3/etc. 
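A minimal sketch, assuming litellm's public `CustomLogger` callback interface, of how a proxy admin might consume the pass-through logging described in the bullet above. The module name `custom_callbacks.py`, the class `PassThroughSpendLogger`, and the `proxy_handler_instance` variable are illustrative assumptions and not part of this patch; only the general callback hook comes from litellm's documented interface.

```python
# custom_callbacks.py -- illustrative sketch, not part of this patch.
from litellm.integrations.custom_logger import CustomLogger


class PassThroughSpendLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # With this patch, Anthropic/Gemini pass-through calls are expected to reach
        # success callbacks like any other proxy call.
        payload = kwargs.get("standard_logging_object") or {}
        print(f"pass-through call: model={payload.get('model')} cost={payload.get('response_cost')}")


proxy_handler_instance = PassThroughSpendLogger()
```

On the proxy, this would typically be wired up through the existing `callbacks` setting in the config, which is assumed here rather than shown in this diff.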
* fix(utils.py): allow disabling end user cost tracking with new param Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small * docs(configs.md): add disable_end_user_cost_tracking reference to docs * feat(key_management_endpoints.py): add support for restricting access to `/key/generate` by team/proxy level role Enables admin to restrict key creation, and assign team admins to handle distributing keys * test(test_key_management.py): add unit testing for personal / team key restriction checks * docs: add docs on restricting key creation * docs(finetuned_models.md): add new guide on calling finetuned models * docs(input.md): cleanup anthropic supported params Closes https://github.com/BerriAI/litellm/issues/6856 * test(test_embedding.py): add test for passing extra headers via embedding * feat(cohere/embed): pass client to async embedding * feat(rerank.py): add `/v1/rerank` if missing for cohere base url Closes https://github.com/BerriAI/litellm/issues/6844 * fix(main.py): pass extra_headers param to openai Fixes https://github.com/BerriAI/litellm/issues/6836 * fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically * fix(handler.py): fix linting error * fix: fix typing * build: add conftest to proxy_admin_ui_tests/ * test: fix test * fix: fix linting errors * test: fix test * fix: fix pass through testing * feat(key_management_endpoints.py): allow proxy_admin to enforce params on key creation allows admin to force team keys to have tags * build(ui/): show teams in leftnav + allow team admin to add new members * build(ui/): show created tags in dropdown makes it easier for admin to add tags to keys * test(test_key_management.py): fix test * test: fix test * fix playwright e2e ui test * fix e2e ui testing deps * fix: fix linting errors * fix e2e ui testing * fix e2e ui testing, only run e2e ui testing in playwright --------- Co-authored-by: Ishaan Jaff --- docs/my-website/docs/proxy/virtual_keys.md | 3 + litellm/proxy/_new_secret_config.yaml | 24 ---- .../key_management_endpoints.py | 114 ++++++++++++++---- litellm/types/utils.py | 10 +- .../test_key_management.py | 82 +++++++++++-- .../src/components/create_key_button.tsx | 36 ++++++ .../src/components/leftnav.tsx | 2 +- .../src/components/networking.tsx | 6 +- ui/litellm-dashboard/src/components/teams.tsx | 73 +++++++---- 9 files changed, 270 insertions(+), 80 deletions(-) diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 98b06d33b..5bbb6b2a0 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -820,6 +820,7 @@ litellm_settings: key_generation_settings: team_key_generation: allowed_team_member_roles: ["admin"] + required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key personal_key_generation: # maps to 'Default Team' on UI allowed_user_roles: ["proxy_admin"] ``` @@ -829,10 +830,12 @@ litellm_settings: ```python class TeamUIKeyGenerationConfig(TypedDict): allowed_team_member_roles: List[str] + required_params: List[str] # require params on `/key/generate` to be set if a team key (team_id in request) is being generated class PersonalUIKeyGenerationConfig(TypedDict): allowed_user_roles: List[LitellmUserRoles] + required_params: List[str] # require params on `/key/generate` to be set if a personal key (no team_id in 
request) is being generated class StandardKeyGenerationConfig(TypedDict, total=False): diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 7baf2224c..7ff209094 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -11,28 +11,4 @@ model_list: model: vertex_ai/claude-3-5-sonnet-v2 vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" - - model_name: fake-openai-endpoint - litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - -router_settings: - model_group_alias: - "gpt-4-turbo": # Aliased model name - model: "gpt-4" # Actual model name in 'model_list' - hidden: true -litellm_settings: - default_team_settings: - - team_id: team-1 - success_callback: ["langfuse"] - failure_callback: ["langfuse"] - langfuse_public_key: os.environ/LANGFUSE_PROJECT1_PUBLIC # Project 1 - langfuse_secret: os.environ/LANGFUSE_PROJECT1_SECRET # Project 1 - - team_id: team-2 - success_callback: ["langfuse"] - failure_callback: ["langfuse"] - langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2 - langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2 - langfuse_host: https://us.cloud.langfuse.com diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index ab13616d5..511e5a940 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -39,16 +39,20 @@ from litellm.proxy.utils import ( handle_exception_on_proxy, ) from litellm.secret_managers.main import get_secret +from litellm.types.utils import PersonalUIKeyGenerationConfig, TeamUIKeyGenerationConfig def _is_team_key(data: GenerateKeyRequest): return data.team_id is not None -def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth): +def _team_key_generation_team_member_check( + user_api_key_dict: UserAPIKeyAuth, + team_key_generation: Optional[TeamUIKeyGenerationConfig], +): if ( - litellm.key_generation_settings is None - or litellm.key_generation_settings.get("team_key_generation") is None + team_key_generation is None + or "allowed_team_member_roles" not in team_key_generation ): return True @@ -59,12 +63,7 @@ def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth): ) team_member_role = user_api_key_dict.team_member.role - if ( - team_member_role - not in litellm.key_generation_settings["team_key_generation"][ # type: ignore - "allowed_team_member_roles" - ] - ): + if team_member_role not in team_key_generation["allowed_team_member_roles"]: raise HTTPException( status_code=400, detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore @@ -72,7 +71,67 @@ def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth): return True -def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth): +def _key_generation_required_param_check( + data: GenerateKeyRequest, required_params: Optional[List[str]] +): + if required_params is None: + return True + + data_dict = data.model_dump(exclude_unset=True) + for param in required_params: + if param not in data_dict: + raise HTTPException( + status_code=400, + detail=f"Required param {param} not in data", + ) + return True + + +def _team_key_generation_check( + user_api_key_dict: UserAPIKeyAuth, data: 
GenerateKeyRequest +): + if ( + litellm.key_generation_settings is None + or litellm.key_generation_settings.get("team_key_generation") is None + ): + return True + + _team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore + + _team_key_generation_team_member_check( + user_api_key_dict, + team_key_generation=_team_key_generation, + ) + _key_generation_required_param_check( + data, + _team_key_generation.get("required_params"), + ) + + return True + + +def _personal_key_membership_check( + user_api_key_dict: UserAPIKeyAuth, + personal_key_generation: Optional[PersonalUIKeyGenerationConfig], +): + if ( + personal_key_generation is None + or "allowed_user_roles" not in personal_key_generation + ): + return True + + if user_api_key_dict.user_role not in personal_key_generation["allowed_user_roles"]: + raise HTTPException( + status_code=400, + detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. Your role={user_api_key_dict.user_role}", # type: ignore + ) + + return True + + +def _personal_key_generation_check( + user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest +): if ( litellm.key_generation_settings is None @@ -80,16 +139,18 @@ def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth): ): return True - if ( - user_api_key_dict.user_role - not in litellm.key_generation_settings["personal_key_generation"][ # type: ignore - "allowed_user_roles" - ] - ): - raise HTTPException( - status_code=400, - detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. Your role={user_api_key_dict.user_role}", # type: ignore - ) + _personal_key_generation = litellm.key_generation_settings["personal_key_generation"] # type: ignore + + _personal_key_membership_check( + user_api_key_dict, + personal_key_generation=_personal_key_generation, + ) + + _key_generation_required_param_check( + data, + _personal_key_generation.get("required_params"), + ) + return True @@ -99,16 +160,23 @@ def key_generation_check( """ Check if admin has restricted key creation to certain roles for teams or individuals """ - if litellm.key_generation_settings is None: + if ( + litellm.key_generation_settings is None + or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value + ): return True ## check if key is for team or individual is_team_key = _is_team_key(data=data) if is_team_key: - return _team_key_generation_check(user_api_key_dict) + return _team_key_generation_check( + user_api_key_dict=user_api_key_dict, data=data + ) else: - return _personal_key_generation_check(user_api_key_dict=user_api_key_dict) + return _personal_key_generation_check( + user_api_key_dict=user_api_key_dict, data=data + ) router = APIRouter() diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 334894320..9fc58dff6 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1604,11 +1604,17 @@ class StandardCallbackDynamicParams(TypedDict, total=False): langsmith_base_url: Optional[str] -class TeamUIKeyGenerationConfig(TypedDict): +class KeyGenerationConfig(TypedDict, total=False): + required_params: List[ + str + ] # specify params that must be present in the key generation request + + +class TeamUIKeyGenerationConfig(KeyGenerationConfig): allowed_team_member_roles: List[str] -class PersonalUIKeyGenerationConfig(TypedDict): +class 
PersonalUIKeyGenerationConfig(KeyGenerationConfig): allowed_user_roles: List[str] diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 81d9fb676..0b392a268 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -551,7 +551,7 @@ def test_is_team_key(): assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id")) -def test_team_key_generation_check(): +def test_team_key_generation_team_member_check(): from litellm.proxy.management_endpoints.key_management_endpoints import ( _team_key_generation_check, ) @@ -562,22 +562,86 @@ def test_team_key_generation_check(): } assert _team_key_generation_check( - UserAPIKeyAuth( + user_api_key_dict=UserAPIKeyAuth( user_role=LitellmUserRoles.INTERNAL_USER, api_key="sk-1234", team_member=Member(role="admin", user_id="test_user_id"), - ) + ), + data=GenerateKeyRequest(), ) with pytest.raises(HTTPException): _team_key_generation_check( - UserAPIKeyAuth( + user_api_key_dict=UserAPIKeyAuth( user_role=LitellmUserRoles.INTERNAL_USER, api_key="sk-1234", user_id="test_user_id", team_member=Member(role="user", user_id="test_user_id"), + ), + data=GenerateKeyRequest(), + ) + + +@pytest.mark.parametrize( + "team_key_generation_settings, input_data, expected_result", + [ + ({"required_params": ["tags"]}, GenerateKeyRequest(tags=["test_tags"]), True), + ({}, GenerateKeyRequest(), True), + ( + {"required_params": ["models"]}, + GenerateKeyRequest(tags=["test_tags"]), + False, + ), + ], +) +@pytest.mark.parametrize("key_type", ["team_key", "personal_key"]) +def test_key_generation_required_params_check( + team_key_generation_settings, input_data, expected_result, key_type +): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _team_key_generation_check, + _personal_key_generation_check, + ) + from litellm.types.utils import ( + TeamUIKeyGenerationConfig, + StandardKeyGenerationConfig, + PersonalUIKeyGenerationConfig, + ) + from fastapi import HTTPException + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + user_id="test_user_id", + team_id="test_team_id", + team_member=Member(role="admin", user_id="test_user_id"), + ) + + if key_type == "team_key": + litellm.key_generation_settings = StandardKeyGenerationConfig( + team_key_generation=TeamUIKeyGenerationConfig( + **team_key_generation_settings ) ) + elif key_type == "personal_key": + litellm.key_generation_settings = StandardKeyGenerationConfig( + personal_key_generation=PersonalUIKeyGenerationConfig( + **team_key_generation_settings + ) + ) + + if expected_result: + if key_type == "team_key": + assert _team_key_generation_check(user_api_key_dict, input_data) + elif key_type == "personal_key": + assert _personal_key_generation_check(user_api_key_dict, input_data) + else: + if key_type == "team_key": + with pytest.raises(HTTPException): + _team_key_generation_check(user_api_key_dict, input_data) + elif key_type == "personal_key": + with pytest.raises(HTTPException): + _personal_key_generation_check(user_api_key_dict, input_data) def test_personal_key_generation_check(): @@ -591,16 +655,18 @@ def test_personal_key_generation_check(): } assert _personal_key_generation_check( - UserAPIKeyAuth( + user_api_key_dict=UserAPIKeyAuth( user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin" - ) + ), + data=GenerateKeyRequest(), ) with pytest.raises(HTTPException): 
_personal_key_generation_check( - UserAPIKeyAuth( + user_api_key_dict=UserAPIKeyAuth( user_role=LitellmUserRoles.INTERNAL_USER, api_key="sk-1234", user_id="admin", - ) + ), + data=GenerateKeyRequest(), ) diff --git a/ui/litellm-dashboard/src/components/create_key_button.tsx b/ui/litellm-dashboard/src/components/create_key_button.tsx index 0af3a064c..4f771b111 100644 --- a/ui/litellm-dashboard/src/components/create_key_button.tsx +++ b/ui/litellm-dashboard/src/components/create_key_button.tsx @@ -40,6 +40,31 @@ interface CreateKeyProps { setData: React.Dispatch>; } +const getPredefinedTags = (data: any[] | null) => { + let allTags = []; + + console.log("data:", JSON.stringify(data)); + + if (data) { + for (let key of data) { + if (key["metadata"] && key["metadata"]["tags"]) { + allTags.push(...key["metadata"]["tags"]); + } + } + } + + // Deduplicate using Set + const uniqueTags = Array.from(new Set(allTags)).map(tag => ({ + value: tag, + label: tag, + })); + + + console.log("uniqueTags:", uniqueTags); + return uniqueTags; +} + + const CreateKey: React.FC = ({ userID, team, @@ -55,6 +80,8 @@ const CreateKey: React.FC = ({ const [userModels, setUserModels] = useState([]); const [modelsToPick, setModelsToPick] = useState([]); const [keyOwner, setKeyOwner] = useState("you"); + const [predefinedTags, setPredefinedTags] = useState(getPredefinedTags(data)); + const handleOk = () => { setIsModalVisible(false); @@ -355,6 +382,15 @@ const CreateKey: React.FC = ({ placeholder="Enter metadata as JSON" /> + +
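A rough client-side sketch of how the `required_params` enforcement added in this patch series would behave, assuming a proxy at `http://localhost:4000`, a team-admin virtual key `sk-team-admin` (the proxy-admin role bypasses these checks per `key_generation_check` above), and `team_key_generation.required_params: ["tags"]` in the config; all concrete values are assumptions for illustration.

```python
# Illustrative sketch; the endpoint shape follows the existing /key/generate API.
import requests

PROXY_BASE = "http://localhost:4000"   # assumed local proxy
TEAM_ADMIN_KEY = "sk-team-admin"       # assumed key of a team admin (not a proxy admin)
headers = {"Authorization": f"Bearer {TEAM_ADMIN_KEY}"}

# Team key without the required "tags" param -> expected 400 ("Required param tags not in data").
resp = requests.post(
    f"{PROXY_BASE}/key/generate",
    headers=headers,
    json={"team_id": "test_team_id"},
)
print(resp.status_code, resp.text)

# Same request with tags supplied -> expected to succeed and return a new virtual key.
resp = requests.post(
    f"{PROXY_BASE}/key/generate",
    headers=headers,
    json={"team_id": "test_team_id", "tags": ["prod", "cost-center-a"]},
)
print(resp.status_code, resp.json().get("key"))
```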