mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
Optional labels
field in Vertex AI request
If the client sets the `labels` field in the request to the LiteLLM: - pass the `labels` field to the Vertex AI backend If the client sets the `metadata` field in the request to the LiteLLM: - if the `labels` field is not set, fill it with `metadata` key/value pairs for all string values
This commit is contained in:
parent
982d32ab91
commit
87ec54cf04
3 changed files with 133 additions and 0 deletions
|
@ -328,6 +328,17 @@ def _transform_request_body(
|
|||
) # type: ignore
|
||||
config_fields = GenerationConfig.__annotations__.keys()
|
||||
|
||||
# If the LiteLLM client sends Gemini-supported parameter "labels", add it
|
||||
# as "labels" field to the request sent to the Gemini backend.
|
||||
labels: Optional[dict[str, str]] = optional_params.pop("labels", None)
|
||||
# If the LiteLLM client sends OpenAI-supported parameter "metadata", add it
|
||||
# as "labels" field to the request sent to the Gemini backend.
|
||||
if labels is None and "metadata" in litellm_params:
|
||||
metadata = litellm_params["metadata"]
|
||||
if "requester_metadata" in metadata:
|
||||
rm = metadata["requester_metadata"]
|
||||
labels = {k: v for k, v in rm.items() if type(v) is str}
|
||||
|
||||
filtered_params = {
|
||||
k: v for k, v in optional_params.items() if k in config_fields
|
||||
}
|
||||
|
@ -348,6 +359,8 @@ def _transform_request_body(
|
|||
data["generationConfig"] = generation_config
|
||||
if cached_content is not None:
|
||||
data["cachedContent"] = cached_content
|
||||
if labels is not None:
|
||||
data["labels"] = labels
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
|
|
@ -210,6 +210,7 @@ class RequestBody(TypedDict, total=False):
|
|||
safetySettings: List[SafetSettingsConfig]
|
||||
generationConfig: GenerationConfig
|
||||
cachedContent: str
|
||||
labels: dict[str, str]
|
||||
|
||||
|
||||
class CachedContentRequestBody(TypedDict, total=False):
|
||||
|
|
119
tests/litellm/llms/vertex_ai/gemini/test_transformation.py
Normal file
119
tests/litellm/llms/vertex_ai/gemini/test_transformation.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from litellm.llms.vertex_ai.gemini import transformation
|
||||
from litellm.types.llms import openai
|
||||
from litellm.types import completion
|
||||
from litellm.types.llms.vertex_ai import RequestBody
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__transform_request_body_labels():
|
||||
"""
|
||||
Test that Vertex AI requests use the optional Vertex AI
|
||||
"labels" parameters sent by client.
|
||||
"""
|
||||
|
||||
# Set up the test parameters
|
||||
model = "vertex_ai/gemini-1.5-pro"
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
optional_params = {
|
||||
"labels": {"lparam1": "lvalue1", "lparam2": "lvalue2"}
|
||||
}
|
||||
litellm_params = {}
|
||||
transform_request_params = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"optional_params": optional_params,
|
||||
"custom_llm_provider": "vertex_ai",
|
||||
"litellm_params": litellm_params,
|
||||
"cached_content": None,
|
||||
}
|
||||
|
||||
rb: RequestBody = transformation._transform_request_body(**transform_request_params)
|
||||
|
||||
# Check URL
|
||||
assert rb["contents"] == [{'parts': [{'text': 'hi'}], 'role': 'user'}, {'parts': [{'text': 'Hello! How can I assist you today?'}], 'role': 'model'}, {'parts': [{'text': 'hi'}], 'role': 'user'}]
|
||||
assert "labels" in rb and rb["labels"] == {"lparam1": "lvalue1", "lparam2": "lvalue2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__transform_request_body_metadata():
|
||||
"""
|
||||
Test that Vertex AI requests use the optional Open AI
|
||||
"metadata" parameters sent by client.
|
||||
"""
|
||||
|
||||
# Set up the test parameters
|
||||
model = "vertex_ai/gemini-1.5-pro"
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
optional_params = {}
|
||||
litellm_params = {
|
||||
"metadata": {
|
||||
"requester_metadata": {"rparam1": "rvalue1", "rparam2": "rvalue2"}
|
||||
}
|
||||
}
|
||||
transform_request_params = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"optional_params": optional_params,
|
||||
"custom_llm_provider": "vertex_ai",
|
||||
"litellm_params": litellm_params,
|
||||
"cached_content": None,
|
||||
}
|
||||
|
||||
rb: RequestBody = transformation._transform_request_body(**transform_request_params)
|
||||
|
||||
# Check URL
|
||||
assert rb["contents"] == [{'parts': [{'text': 'hi'}], 'role': 'user'}, {'parts': [{'text': 'Hello! How can I assist you today?'}], 'role': 'model'}, {'parts': [{'text': 'hi'}], 'role': 'user'}]
|
||||
assert "labels" in rb and rb["labels"] == {"rparam1": "rvalue1", "rparam2": "rvalue2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__transform_request_body_labels_and_metadata():
|
||||
"""
|
||||
Test that Vertex AI requests use the optional Vertex AI
|
||||
"labels" parameters sent by client and that the "metadata"
|
||||
optional Open AI parameters are ignored if the client uses
|
||||
"labels" parameters.
|
||||
"""
|
||||
|
||||
# Set up the test parameters
|
||||
model = "vertex_ai/gemini-1.5-pro"
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
optional_params = {
|
||||
"labels": {"lparam1": "lvalue1", "lparam2": "lvalue2"}
|
||||
}
|
||||
litellm_params = {
|
||||
"metadata": {
|
||||
"requester_metadata": {"rparam1": "rvalue1", "rparam2": "rvalue2"}
|
||||
}
|
||||
}
|
||||
transform_request_params = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"optional_params": optional_params,
|
||||
"custom_llm_provider": "vertex_ai",
|
||||
"litellm_params": litellm_params,
|
||||
"cached_content": None,
|
||||
}
|
||||
|
||||
rb: RequestBody = transformation._transform_request_body(**transform_request_params)
|
||||
|
||||
# Check URL
|
||||
assert rb["contents"] == [{'parts': [{'text': 'hi'}], 'role': 'user'}, {'parts': [{'text': 'Hello! How can I assist you today?'}], 'role': 'model'}, {'parts': [{'text': 'hi'}], 'role': 'user'}]
|
||||
assert "labels" in rb and rb["labels"] == {"lparam1": "lvalue1", "lparam2": "lvalue2"}
|
Loading…
Add table
Add a link
Reference in a new issue