mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix(vertex_ai.py): move to only passing in accepted keys by vertex ai response schema (#8992)
* fix(vertex_ai.py): common_utils.py move to only passing in accepted keys by vertex ai prevent json schema compatible keys like $id, and $comment from causing vertex ai openapi calls to fail * fix(test_vertex.py): add testing to ensure only accepted schema params passed in * fix(common_utils.py): fix linting error * test: update test * test: accept function
This commit is contained in:
parent
4a128cfd64
commit
8e3c7b2de0
4 changed files with 78 additions and 30 deletions
|
@ -1,5 +1,5 @@
|
|||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
|
||||
import re
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -7,7 +7,7 @@ import litellm
|
|||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.vertex_ai import PartType
|
||||
from litellm.types.llms.vertex_ai import PartType, Schema
|
||||
|
||||
|
||||
class VertexAIError(BaseLLMException):
|
||||
|
@ -168,6 +168,9 @@ def _build_vertex_schema(parameters: dict):
|
|||
"""
|
||||
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
|
||||
"""
|
||||
# Get valid fields from Schema TypedDict
|
||||
valid_schema_fields = set(get_type_hints(Schema).keys())
|
||||
|
||||
defs = parameters.pop("$defs", {})
|
||||
# flatten the defs
|
||||
for name, value in defs.items():
|
||||
|
@ -181,19 +184,49 @@ def _build_vertex_schema(parameters: dict):
|
|||
convert_anyof_null_to_nullable(parameters)
|
||||
add_object_type(parameters)
|
||||
# Postprocessing
|
||||
# 4. Suppress unnecessary title generation:
|
||||
# * https://github.com/pydantic/pydantic/issues/1051
|
||||
# * http://cl/586221780
|
||||
strip_field(parameters, field_name="title")
|
||||
# Filter out fields that don't exist in Schema
|
||||
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
|
||||
return filtered_parameters
|
||||
|
||||
strip_field(
|
||||
parameters, field_name="$schema"
|
||||
) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
|
||||
strip_field(
|
||||
parameters, field_name="$id"
|
||||
) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
|
||||
|
||||
return parameters
|
||||
def filter_schema_fields(
|
||||
schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Recursively filter a schema dictionary to keep only valid fields.
|
||||
"""
|
||||
if processed is None:
|
||||
processed = set()
|
||||
|
||||
# Handle circular references
|
||||
schema_id = id(schema_dict)
|
||||
if schema_id in processed:
|
||||
return schema_dict
|
||||
processed.add(schema_id)
|
||||
|
||||
if not isinstance(schema_dict, dict):
|
||||
return schema_dict
|
||||
|
||||
result = {}
|
||||
for key, value in schema_dict.items():
|
||||
if key not in valid_fields:
|
||||
continue
|
||||
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
result[key] = {
|
||||
k: filter_schema_fields(v, valid_fields, processed)
|
||||
for k, v in value.items()
|
||||
}
|
||||
elif key == "items" and isinstance(value, dict):
|
||||
result[key] = filter_schema_fields(value, valid_fields, processed)
|
||||
elif key == "anyOf" and isinstance(value, list):
|
||||
result[key] = [
|
||||
filter_schema_fields(item, valid_fields, processed) for item in value # type: ignore
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def unpack_defs(schema, defs):
|
||||
|
|
|
@ -87,12 +87,27 @@ class SystemInstructions(TypedDict):
|
|||
|
||||
class Schema(TypedDict, total=False):
|
||||
type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
|
||||
format: str
|
||||
title: str
|
||||
description: str
|
||||
enum: List[str]
|
||||
items: List["Schema"]
|
||||
properties: "Schema"
|
||||
required: List[str]
|
||||
nullable: bool
|
||||
default: Any
|
||||
items: "Schema"
|
||||
minItems: str
|
||||
maxItems: str
|
||||
enum: List[str]
|
||||
properties: Dict[str, "Schema"]
|
||||
propertyOrdering: List[str]
|
||||
required: List[str]
|
||||
minProperties: str
|
||||
maxProperties: str
|
||||
minimum: float
|
||||
maximum: float
|
||||
minLength: str
|
||||
maxLength: str
|
||||
pattern: str
|
||||
example: Any
|
||||
anyOf: List["Schema"]
|
||||
|
||||
|
||||
class FunctionDeclaration(TypedDict, total=False):
|
||||
|
|
|
@ -5,6 +5,7 @@ IGNORE_FUNCTIONS = [
|
|||
"_format_type",
|
||||
"_remove_additional_properties",
|
||||
"_remove_strict_from_schema",
|
||||
"filter_schema_fields",
|
||||
"text_completion",
|
||||
"_check_for_os_environ_vars",
|
||||
"clean_message",
|
||||
|
|
|
@ -63,26 +63,24 @@ def test_completion_pydantic_obj_2():
|
|||
"events": {
|
||||
"items": {
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"date": {"type": "string"},
|
||||
"name": {"title": "Name", "type": "string"},
|
||||
"date": {"title": "Date", "type": "string"},
|
||||
"participants": {
|
||||
"items": {"type": "string"},
|
||||
"title": "Participants",
|
||||
"type": "array",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"name",
|
||||
"date",
|
||||
"participants",
|
||||
],
|
||||
"required": ["name", "date", "participants"],
|
||||
"title": "CalendarEvent",
|
||||
"type": "object",
|
||||
},
|
||||
"title": "Events",
|
||||
"type": "array",
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"events",
|
||||
],
|
||||
"required": ["events"],
|
||||
"title": "EventsList",
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
|
@ -91,12 +89,13 @@ def test_completion_pydantic_obj_2():
|
|||
with patch.object(client, "post", new=MagicMock()) as mock_post:
|
||||
mock_post.return_value = expected_request_body
|
||||
try:
|
||||
litellm.completion(
|
||||
response = litellm.completion(
|
||||
model="gemini/gemini-1.5-pro",
|
||||
messages=messages,
|
||||
response_format=EventsList,
|
||||
client=client,
|
||||
)
|
||||
# print(response)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
@ -115,7 +114,7 @@ def test_build_vertex_schema():
|
|||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"$id": "my-special-id",
|
||||
"my-random-key": "my-random-value",
|
||||
"properties": {
|
||||
"recipes": {
|
||||
"type": "array",
|
||||
|
@ -134,7 +133,7 @@ def test_build_vertex_schema():
|
|||
assert new_schema["type"] == schema["type"]
|
||||
assert new_schema["properties"] == schema["properties"]
|
||||
assert "required" in new_schema and new_schema["required"] == schema["required"]
|
||||
assert "$id" not in new_schema
|
||||
assert "my-random-key" not in new_schema
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue