From 87ec54cf04677a95b0028ea377db6533d5649b6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vedran=20Vidovi=C4=87?=
Date: Wed, 12 Mar 2025 14:10:32 +0100
Subject: [PATCH] Optional `labels` field in Vertex AI request

If the client sets the `labels` field in a request to LiteLLM:
- pass the `labels` field through to the Vertex AI backend

If the client sets the `metadata` field in a request to LiteLLM:
- if `labels` is not set, fill it with the string-valued key/value pairs
  of `metadata["requester_metadata"]`
---
 .../llms/vertex_ai/gemini/transformation.py  |  13 ++
 litellm/types/llms/vertex_ai.py              |   1 +
 .../vertex_ai/gemini/test_transformation.py  | 128 ++++++++++++++++++
 3 files changed, 142 insertions(+)
 create mode 100644 tests/litellm/llms/vertex_ai/gemini/test_transformation.py

diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py
index d6bafc7c60..e93d9ed2ff 100644
--- a/litellm/llms/vertex_ai/gemini/transformation.py
+++ b/litellm/llms/vertex_ai/gemini/transformation.py
@@ -328,6 +328,17 @@ def _transform_request_body(
         )  # type: ignore
         config_fields = GenerationConfig.__annotations__.keys()
 
+        # If the client sends the Gemini-supported "labels" parameter, pass it
+        # through unchanged as the "labels" field of the Gemini request.
+        labels: Optional[dict[str, str]] = optional_params.pop("labels", None)
+        # Otherwise, derive "labels" from the string-valued entries of the
+        # OpenAI-style metadata["requester_metadata"] sent by the client.
+        if labels is None and "metadata" in litellm_params:
+            metadata = litellm_params["metadata"]
+            if "requester_metadata" in metadata:
+                rm = metadata["requester_metadata"]
+                labels = {k: v for k, v in rm.items() if isinstance(v, str)}
+
         filtered_params = {
             k: v for k, v in optional_params.items() if k in config_fields
         }
@@ -348,6 +359,8 @@ def _transform_request_body(
         data["generationConfig"] = generation_config
         if cached_content is not None:
             data["cachedContent"] = cached_content
+        if labels is not None:
+            data["labels"] = labels
     except Exception as e:
         raise e
 
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index 7024909a34..4ae06a2e3f 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -210,6 +210,7 @@ class RequestBody(TypedDict, total=False):
     safetySettings: List[SafetSettingsConfig]
     generationConfig: GenerationConfig
     cachedContent: str
+    labels: dict[str, str]
 
 
 class CachedContentRequestBody(TypedDict, total=False):
diff --git a/tests/litellm/llms/vertex_ai/gemini/test_transformation.py b/tests/litellm/llms/vertex_ai/gemini/test_transformation.py
new file mode 100644
index 0000000000..7fd528f169
--- /dev/null
+++ b/tests/litellm/llms/vertex_ai/gemini/test_transformation.py
@@ -0,0 +1,128 @@
+import os
+import sys
+
+sys.path.insert(
+    0, os.path.abspath("../../../../..")
+)  # Add the repository root to the system path
+from litellm.llms.vertex_ai.gemini import transformation
+from litellm.types.llms.vertex_ai import RequestBody
+
+
+def test__transform_request_body_labels():
+    """
+    Test that Vertex AI requests use the optional Vertex AI
+    "labels" parameter sent by the client.
+    """
+    # Set up the test parameters
+    model = "vertex_ai/gemini-1.5-pro"
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "hi"},
+    ]
+    optional_params = {"labels": {"lparam1": "lvalue1", "lparam2": "lvalue2"}}
+    litellm_params = {}
+    transform_request_params = {
+        "messages": messages,
+        "model": model,
+        "optional_params": optional_params,
+        "custom_llm_provider": "vertex_ai",
+        "litellm_params": litellm_params,
+        "cached_content": None,
+    }
+
+    rb: RequestBody = transformation._transform_request_body(
+        **transform_request_params
+    )
+
+    # Check the transformed request body
+    assert rb["contents"] == [
+        {"parts": [{"text": "hi"}], "role": "user"},
+        {"parts": [{"text": "Hello! How can I assist you today?"}], "role": "model"},
+        {"parts": [{"text": "hi"}], "role": "user"},
+    ]
+    assert "labels" in rb
+    assert rb["labels"] == {"lparam1": "lvalue1", "lparam2": "lvalue2"}
+
+
+def test__transform_request_body_metadata():
+    """
+    Test that Vertex AI requests fall back to the optional OpenAI
+    "metadata" parameter sent by the client.
+    """
+    # Set up the test parameters
+    model = "vertex_ai/gemini-1.5-pro"
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "hi"},
+    ]
+    optional_params = {}
+    litellm_params = {
+        "metadata": {
+            "requester_metadata": {"rparam1": "rvalue1", "rparam2": "rvalue2"}
+        }
+    }
+    transform_request_params = {
+        "messages": messages,
+        "model": model,
+        "optional_params": optional_params,
+        "custom_llm_provider": "vertex_ai",
+        "litellm_params": litellm_params,
+        "cached_content": None,
+    }
+
+    rb: RequestBody = transformation._transform_request_body(
+        **transform_request_params
+    )
+
+    # Check the transformed request body
+    assert rb["contents"] == [
+        {"parts": [{"text": "hi"}], "role": "user"},
+        {"parts": [{"text": "Hello! How can I assist you today?"}], "role": "model"},
+        {"parts": [{"text": "hi"}], "role": "user"},
+    ]
+    assert "labels" in rb
+    assert rb["labels"] == {"rparam1": "rvalue1", "rparam2": "rvalue2"}
+
+
+def test__transform_request_body_labels_and_metadata():
+    """
+    Test that Vertex AI requests use the optional Vertex AI
+    "labels" parameter sent by the client, and that the optional
+    OpenAI "metadata" parameter is ignored when "labels" is set.
+    """
+    # Set up the test parameters
+    model = "vertex_ai/gemini-1.5-pro"
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "hi"},
+    ]
+    optional_params = {"labels": {"lparam1": "lvalue1", "lparam2": "lvalue2"}}
+    litellm_params = {
+        "metadata": {
+            "requester_metadata": {"rparam1": "rvalue1", "rparam2": "rvalue2"}
+        }
+    }
+    transform_request_params = {
+        "messages": messages,
+        "model": model,
+        "optional_params": optional_params,
+        "custom_llm_provider": "vertex_ai",
+        "litellm_params": litellm_params,
+        "cached_content": None,
+    }
+
+    rb: RequestBody = transformation._transform_request_body(
+        **transform_request_params
+    )
+
+    # Check the transformed request body
+    assert rb["contents"] == [
+        {"parts": [{"text": "hi"}], "role": "user"},
+        {"parts": [{"text": "Hello! How can I assist you today?"}], "role": "model"},
+        {"parts": [{"text": "hi"}], "role": "user"},
+    ]
+    assert "labels" in rb
+    assert rb["labels"] == {"lparam1": "lvalue1", "lparam2": "lvalue2"}
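
Note for reviewers: a minimal sketch of how a client could exercise this
change. The model name and label values are illustrative only, and the
routing of the extra `labels` kwarg and of `metadata` into
`optional_params` / `litellm_params` is assumed from the tests above
rather than confirmed against the full LiteLLM call path:

    import litellm

    # Vertex AI-style: "labels" should be forwarded verbatim to the
    # Vertex AI request body (hypothetical label values).
    resp = litellm.completion(
        model="vertex_ai/gemini-1.5-pro",
        messages=[{"role": "user", "content": "hi"}],
        labels={"team": "search", "env": "prod"},
    )

    # OpenAI-style fallback: when "labels" is absent, the string-valued
    # entries of metadata["requester_metadata"] should be mapped to "labels".
    resp = litellm.completion(
        model="vertex_ai/gemini-1.5-pro",
        messages=[{"role": "user", "content": "hi"}],
        metadata={"requester_metadata": {"team": "search", "env": "prod"}},
    )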