diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py index 5fdd8e40c..11424e2e4 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -91,7 +91,7 @@ def _get_vertex_url( url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" if not url or not endpoint: - raise ValueError(f"Unable to get vertex url/endpoinit for mode: {mode}") + raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}") return url, endpoint @@ -142,3 +142,116 @@ def _check_text_in_content(parts: List[PartType]) -> bool: has_text_param = True return has_text_param + + +def _build_vertex_schema(parameters: dict): + """ + This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419 + """ + defs = parameters.pop("$defs", {}) + # flatten the defs + for name, value in defs.items(): + unpack_defs(value, defs) + unpack_defs(parameters, defs) + + # 5. Nullable fields: + # * https://github.com/pydantic/pydantic/issues/1270 + # * https://stackoverflow.com/a/58841311 + # * https://github.com/pydantic/pydantic/discussions/4872 + convert_to_nullable(parameters) + add_object_type(parameters) + # Postprocessing + # 4. Suppress unnecessary title generation: + # * https://github.com/pydantic/pydantic/issues/1051 + # * http://cl/586221780 + strip_titles(parameters) + return parameters + + +def unpack_defs(schema, defs): + properties = schema.get("properties", None) + if properties is None: + return + + for name, value in properties.items(): + ref_key = value.get("$ref", None) + if ref_key is not None: + ref = defs[ref_key.split("defs/")[-1]] + unpack_defs(ref, defs) + properties[name] = ref + continue + + anyof = value.get("anyOf", None) + if anyof is not None: + for i, atype in enumerate(anyof): + ref_key = atype.get("$ref", None) + if ref_key is not None: + ref = defs[ref_key.split("defs/")[-1]] + unpack_defs(ref, defs) + anyof[i] = ref + continue + + items = value.get("items", None) + if items is not None: + ref_key = items.get("$ref", None) + if ref_key is not None: + ref = defs[ref_key.split("defs/")[-1]] + unpack_defs(ref, defs) + value["items"] = ref + continue + + +def convert_to_nullable(schema): + anyof = schema.pop("anyOf", None) + if anyof is not None: + if len(anyof) != 2: + raise ValueError( + "Invalid input: Type Unions are not supported, except for `Optional` types. " + "Please provide an `Optional` type or a non-Union type." + ) + a, b = anyof + if a == {"type": "null"}: + schema.update(b) + elif b == {"type": "null"}: + schema.update(a) + else: + raise ValueError( + "Invalid input: Type Unions are not supported, except for `Optional` types. " + "Please provide an `Optional` type or a non-Union type." + ) + schema["nullable"] = True + + properties = schema.get("properties", None) + if properties is not None: + for name, value in properties.items(): + convert_to_nullable(value) + + items = schema.get("items", None) + if items is not None: + convert_to_nullable(items) + + +def add_object_type(schema): + properties = schema.get("properties", None) + if properties is not None: + schema.pop("required", None) + schema["type"] = "object" + for name, value in properties.items(): + add_object_type(value) + + items = schema.get("items", None) + if items is not None: + add_object_type(items) + + +def strip_titles(schema): + schema.pop("title", None) + + properties = schema.get("properties", None) + if properties is not None: + for name, value in properties.items(): + strip_titles(value) + + items = schema.get("items", None) + if items is not None: + strip_titles(items) diff --git a/litellm/utils.py b/litellm/utils.py index cc52c1f56..c2d0f0b9f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2996,12 +2996,16 @@ def get_optional_params( "vertex_ai_beta", ] ): + from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import ( + _build_vertex_schema, + ) old_schema = copy.deepcopy( non_default_params["response_format"] .get("json_schema", {}) .get("schema") ) new_schema = _remove_additional_properties(schema=old_schema) + new_schema = _build_vertex_schema(parameters=new_schema) non_default_params["response_format"]["json_schema"]["schema"] = new_schema if "tools" in non_default_params and isinstance( non_default_params, list diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 34a0baffb..b573a688b 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -1711,6 +1711,31 @@ def test_completion_perplexity_api(): # test_completion_perplexity_api() +@pytest.mark.skip( + reason="too many requests. Hitting gemini rate limits. Convert to mock test." +) +def test_completion_pydantic_obj_2(): + from pydantic import BaseModel + + litellm.set_verbose = True + + class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + + class EventsList(BaseModel): + events: list[CalendarEvent] + + messages = [ + {"role": "user", "content": "List important events from the 20th century."} + ] + + response = litellm.completion( + model="gemini/gemini-1.5-pro", messages=messages, response_format=EventsList + ) + + @pytest.mark.skip(reason="this test is flaky") def test_completion_perplexity_api_2(): try: