diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index f2cd1ef557..0ff42806be 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -1,5 +1,5 @@ +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints import re -from typing import Dict, List, Literal, Optional, Tuple, Union import httpx @@ -7,7 +7,7 @@ import litellm from litellm import supports_response_schema, supports_system_messages, verbose_logger from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH from litellm.llms.base_llm.chat.transformation import BaseLLMException -from litellm.types.llms.vertex_ai import PartType +from litellm.types.llms.vertex_ai import PartType, Schema class VertexAIError(BaseLLMException): @@ -168,6 +168,9 @@ def _build_vertex_schema(parameters: dict): """ This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419 """ + # Get valid fields from Schema TypedDict + valid_schema_fields = set(get_type_hints(Schema).keys()) + defs = parameters.pop("$defs", {}) # flatten the defs for name, value in defs.items(): @@ -181,19 +184,49 @@ def _build_vertex_schema(parameters: dict): convert_anyof_null_to_nullable(parameters) add_object_type(parameters) # Postprocessing - # 4. Suppress unnecessary title generation: - # * https://github.com/pydantic/pydantic/issues/1051 - # * http://cl/586221780 - strip_field(parameters, field_name="title") + # Filter out fields that don't exist in Schema + filtered_parameters = filter_schema_fields(parameters, valid_schema_fields) + return filtered_parameters - strip_field( - parameters, field_name="$schema" - ) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors. - strip_field( - parameters, field_name="$id" - ) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors. - return parameters +def filter_schema_fields( + schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None +) -> Dict[str, Any]: + """ + Recursively filter a schema dictionary to keep only valid fields. + """ + if processed is None: + processed = set() + + # Handle circular references + schema_id = id(schema_dict) + if schema_id in processed: + return schema_dict + processed.add(schema_id) + + if not isinstance(schema_dict, dict): + return schema_dict + + result = {} + for key, value in schema_dict.items(): + if key not in valid_fields: + continue + + if key == "properties" and isinstance(value, dict): + result[key] = { + k: filter_schema_fields(v, valid_fields, processed) + for k, v in value.items() + } + elif key == "items" and isinstance(value, dict): + result[key] = filter_schema_fields(value, valid_fields, processed) + elif key == "anyOf" and isinstance(value, list): + result[key] = [ + filter_schema_fields(item, valid_fields, processed) for item in value # type: ignore + ] + else: + result[key] = value + + return result def unpack_defs(schema, defs): diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 27d79ec992..7fa167938f 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -87,12 +87,27 @@ class SystemInstructions(TypedDict): class Schema(TypedDict, total=False): type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"] + format: str + title: str description: str - enum: List[str] - items: List["Schema"] - properties: "Schema" - required: List[str] nullable: bool + default: Any + items: "Schema" + minItems: str + maxItems: str + enum: List[str] + properties: Dict[str, "Schema"] + propertyOrdering: List[str] + required: List[str] + minProperties: str + maxProperties: str + minimum: float + maximum: float + minLength: str + maxLength: str + pattern: str + example: Any + anyOf: List["Schema"] class FunctionDeclaration(TypedDict, total=False): diff --git a/tests/code_coverage_tests/recursive_detector.py b/tests/code_coverage_tests/recursive_detector.py index b748d1a517..48fe604dbc 100644 --- a/tests/code_coverage_tests/recursive_detector.py +++ b/tests/code_coverage_tests/recursive_detector.py @@ -5,6 +5,7 @@ IGNORE_FUNCTIONS = [ "_format_type", "_remove_additional_properties", "_remove_strict_from_schema", + "filter_schema_fields", "text_completion", "_check_for_os_environ_vars", "clean_message", diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index d7a373d58d..d821fb415e 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -63,26 +63,24 @@ def test_completion_pydantic_obj_2(): "events": { "items": { "properties": { - "name": {"type": "string"}, - "date": {"type": "string"}, + "name": {"title": "Name", "type": "string"}, + "date": {"title": "Date", "type": "string"}, "participants": { "items": {"type": "string"}, + "title": "Participants", "type": "array", }, }, - "required": [ - "name", - "date", - "participants", - ], + "required": ["name", "date", "participants"], + "title": "CalendarEvent", "type": "object", }, + "title": "Events", "type": "array", } }, - "required": [ - "events", - ], + "required": ["events"], + "title": "EventsList", "type": "object", }, }, @@ -91,12 +89,13 @@ def test_completion_pydantic_obj_2(): with patch.object(client, "post", new=MagicMock()) as mock_post: mock_post.return_value = expected_request_body try: - litellm.completion( + response = litellm.completion( model="gemini/gemini-1.5-pro", messages=messages, response_format=EventsList, client=client, ) + # print(response) except Exception as e: print(e) @@ -115,7 +114,7 @@ def test_build_vertex_schema(): schema = { "type": "object", - "$id": "my-special-id", + "my-random-key": "my-random-value", "properties": { "recipes": { "type": "array", @@ -134,7 +133,7 @@ def test_build_vertex_schema(): assert new_schema["type"] == schema["type"] assert new_schema["properties"] == schema["properties"] assert "required" in new_schema and new_schema["required"] == schema["required"] - assert "$id" not in new_schema + assert "my-random-key" not in new_schema @pytest.mark.parametrize(