mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(vertex_ai.py): move to only passing in accepted keys by vertex ai response schema (#8992)
* fix(vertex_ai.py): common_utils.py move to only passing in accepted keys by vertex ai prevent json schema compatible keys like $id, and $comment from causing vertex ai openapi calls to fail * fix(test_vertex.py): add testing to ensure only accepted schema params passed in * fix(common_utils.py): fix linting error * test: update test * test: accept function
This commit is contained in:
parent
ae6bc8ac77
commit
578dedb8de
4 changed files with 78 additions and 30 deletions
|
@ -1,5 +1,5 @@
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import litellm
|
||||||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
from litellm.types.llms.vertex_ai import PartType
|
from litellm.types.llms.vertex_ai import PartType, Schema
|
||||||
|
|
||||||
|
|
||||||
class VertexAIError(BaseLLMException):
|
class VertexAIError(BaseLLMException):
|
||||||
|
@ -168,6 +168,9 @@ def _build_vertex_schema(parameters: dict):
|
||||||
"""
|
"""
|
||||||
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
|
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
|
||||||
"""
|
"""
|
||||||
|
# Get valid fields from Schema TypedDict
|
||||||
|
valid_schema_fields = set(get_type_hints(Schema).keys())
|
||||||
|
|
||||||
defs = parameters.pop("$defs", {})
|
defs = parameters.pop("$defs", {})
|
||||||
# flatten the defs
|
# flatten the defs
|
||||||
for name, value in defs.items():
|
for name, value in defs.items():
|
||||||
|
@ -181,19 +184,49 @@ def _build_vertex_schema(parameters: dict):
|
||||||
convert_anyof_null_to_nullable(parameters)
|
convert_anyof_null_to_nullable(parameters)
|
||||||
add_object_type(parameters)
|
add_object_type(parameters)
|
||||||
# Postprocessing
|
# Postprocessing
|
||||||
# 4. Suppress unnecessary title generation:
|
# Filter out fields that don't exist in Schema
|
||||||
# * https://github.com/pydantic/pydantic/issues/1051
|
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
|
||||||
# * http://cl/586221780
|
return filtered_parameters
|
||||||
strip_field(parameters, field_name="title")
|
|
||||||
|
|
||||||
strip_field(
|
|
||||||
parameters, field_name="$schema"
|
|
||||||
) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
|
|
||||||
strip_field(
|
|
||||||
parameters, field_name="$id"
|
|
||||||
) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
|
|
||||||
|
|
||||||
return parameters
|
def filter_schema_fields(
|
||||||
|
schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Recursively filter a schema dictionary to keep only valid fields.
|
||||||
|
"""
|
||||||
|
if processed is None:
|
||||||
|
processed = set()
|
||||||
|
|
||||||
|
# Handle circular references
|
||||||
|
schema_id = id(schema_dict)
|
||||||
|
if schema_id in processed:
|
||||||
|
return schema_dict
|
||||||
|
processed.add(schema_id)
|
||||||
|
|
||||||
|
if not isinstance(schema_dict, dict):
|
||||||
|
return schema_dict
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for key, value in schema_dict.items():
|
||||||
|
if key not in valid_fields:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key == "properties" and isinstance(value, dict):
|
||||||
|
result[key] = {
|
||||||
|
k: filter_schema_fields(v, valid_fields, processed)
|
||||||
|
for k, v in value.items()
|
||||||
|
}
|
||||||
|
elif key == "items" and isinstance(value, dict):
|
||||||
|
result[key] = filter_schema_fields(value, valid_fields, processed)
|
||||||
|
elif key == "anyOf" and isinstance(value, list):
|
||||||
|
result[key] = [
|
||||||
|
filter_schema_fields(item, valid_fields, processed) for item in value # type: ignore
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def unpack_defs(schema, defs):
|
def unpack_defs(schema, defs):
|
||||||
|
|
|
@ -87,12 +87,27 @@ class SystemInstructions(TypedDict):
|
||||||
|
|
||||||
class Schema(TypedDict, total=False):
|
class Schema(TypedDict, total=False):
|
||||||
type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
|
type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
|
||||||
|
format: str
|
||||||
|
title: str
|
||||||
description: str
|
description: str
|
||||||
enum: List[str]
|
|
||||||
items: List["Schema"]
|
|
||||||
properties: "Schema"
|
|
||||||
required: List[str]
|
|
||||||
nullable: bool
|
nullable: bool
|
||||||
|
default: Any
|
||||||
|
items: "Schema"
|
||||||
|
minItems: str
|
||||||
|
maxItems: str
|
||||||
|
enum: List[str]
|
||||||
|
properties: Dict[str, "Schema"]
|
||||||
|
propertyOrdering: List[str]
|
||||||
|
required: List[str]
|
||||||
|
minProperties: str
|
||||||
|
maxProperties: str
|
||||||
|
minimum: float
|
||||||
|
maximum: float
|
||||||
|
minLength: str
|
||||||
|
maxLength: str
|
||||||
|
pattern: str
|
||||||
|
example: Any
|
||||||
|
anyOf: List["Schema"]
|
||||||
|
|
||||||
|
|
||||||
class FunctionDeclaration(TypedDict, total=False):
|
class FunctionDeclaration(TypedDict, total=False):
|
||||||
|
|
|
@ -5,6 +5,7 @@ IGNORE_FUNCTIONS = [
|
||||||
"_format_type",
|
"_format_type",
|
||||||
"_remove_additional_properties",
|
"_remove_additional_properties",
|
||||||
"_remove_strict_from_schema",
|
"_remove_strict_from_schema",
|
||||||
|
"filter_schema_fields",
|
||||||
"text_completion",
|
"text_completion",
|
||||||
"_check_for_os_environ_vars",
|
"_check_for_os_environ_vars",
|
||||||
"clean_message",
|
"clean_message",
|
||||||
|
|
|
@ -63,26 +63,24 @@ def test_completion_pydantic_obj_2():
|
||||||
"events": {
|
"events": {
|
||||||
"items": {
|
"items": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"name": {"type": "string"},
|
"name": {"title": "Name", "type": "string"},
|
||||||
"date": {"type": "string"},
|
"date": {"title": "Date", "type": "string"},
|
||||||
"participants": {
|
"participants": {
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
|
"title": "Participants",
|
||||||
"type": "array",
|
"type": "array",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": [
|
"required": ["name", "date", "participants"],
|
||||||
"name",
|
"title": "CalendarEvent",
|
||||||
"date",
|
|
||||||
"participants",
|
|
||||||
],
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
},
|
},
|
||||||
|
"title": "Events",
|
||||||
"type": "array",
|
"type": "array",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": [
|
"required": ["events"],
|
||||||
"events",
|
"title": "EventsList",
|
||||||
],
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -91,12 +89,13 @@ def test_completion_pydantic_obj_2():
|
||||||
with patch.object(client, "post", new=MagicMock()) as mock_post:
|
with patch.object(client, "post", new=MagicMock()) as mock_post:
|
||||||
mock_post.return_value = expected_request_body
|
mock_post.return_value = expected_request_body
|
||||||
try:
|
try:
|
||||||
litellm.completion(
|
response = litellm.completion(
|
||||||
model="gemini/gemini-1.5-pro",
|
model="gemini/gemini-1.5-pro",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_format=EventsList,
|
response_format=EventsList,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
# print(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
@ -115,7 +114,7 @@ def test_build_vertex_schema():
|
||||||
|
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"$id": "my-special-id",
|
"my-random-key": "my-random-value",
|
||||||
"properties": {
|
"properties": {
|
||||||
"recipes": {
|
"recipes": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
@ -134,7 +133,7 @@ def test_build_vertex_schema():
|
||||||
assert new_schema["type"] == schema["type"]
|
assert new_schema["type"] == schema["type"]
|
||||||
assert new_schema["properties"] == schema["properties"]
|
assert new_schema["properties"] == schema["properties"]
|
||||||
assert "required" in new_schema and new_schema["required"] == schema["required"]
|
assert "required" in new_schema and new_schema["required"] == schema["required"]
|
||||||
assert "$id" not in new_schema
|
assert "my-random-key" not in new_schema
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue