fix(utils.py): fix fix pydantic obj to schema creation for vertex en… (#6071)

* fix(utils.py): fix  fix pydantic obj to schema creation for vertex endpoints

Fixes https://github.com/BerriAI/litellm/issues/6027

* test(test_completion.pyu): skip test - avoid hitting gemini rate limits

* fix(common_utils.py): fix ruff linting error
This commit is contained in:
Krish Dholakia 2024-10-06 00:25:55 -04:00 committed by GitHub
parent 29da2d49d6
commit 49d8b2be46
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 143 additions and 1 deletions

View file

@ -91,7 +91,7 @@ def _get_vertex_url(
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoinit for mode: {mode}")
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
return url, endpoint
@ -142,3 +142,116 @@ def _check_text_in_content(parts: List[PartType]) -> bool:
has_text_param = True
return has_text_param
def _build_vertex_schema(parameters: dict):
"""
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
"""
defs = parameters.pop("$defs", {})
# flatten the defs
for name, value in defs.items():
unpack_defs(value, defs)
unpack_defs(parameters, defs)
# 5. Nullable fields:
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
convert_to_nullable(parameters)
add_object_type(parameters)
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
# * http://cl/586221780
strip_titles(parameters)
return parameters
def unpack_defs(schema, defs):
properties = schema.get("properties", None)
if properties is None:
return
for name, value in properties.items():
ref_key = value.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
properties[name] = ref
continue
anyof = value.get("anyOf", None)
if anyof is not None:
for i, atype in enumerate(anyof):
ref_key = atype.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
anyof[i] = ref
continue
items = value.get("items", None)
if items is not None:
ref_key = items.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
value["items"] = ref
continue
def convert_to_nullable(schema):
anyof = schema.pop("anyOf", None)
if anyof is not None:
if len(anyof) != 2:
raise ValueError(
"Invalid input: Type Unions are not supported, except for `Optional` types. "
"Please provide an `Optional` type or a non-Union type."
)
a, b = anyof
if a == {"type": "null"}:
schema.update(b)
elif b == {"type": "null"}:
schema.update(a)
else:
raise ValueError(
"Invalid input: Type Unions are not supported, except for `Optional` types. "
"Please provide an `Optional` type or a non-Union type."
)
schema["nullable"] = True
properties = schema.get("properties", None)
if properties is not None:
for name, value in properties.items():
convert_to_nullable(value)
items = schema.get("items", None)
if items is not None:
convert_to_nullable(items)
def add_object_type(schema):
properties = schema.get("properties", None)
if properties is not None:
schema.pop("required", None)
schema["type"] = "object"
for name, value in properties.items():
add_object_type(value)
items = schema.get("items", None)
if items is not None:
add_object_type(items)
def strip_titles(schema):
schema.pop("title", None)
properties = schema.get("properties", None)
if properties is not None:
for name, value in properties.items():
strip_titles(value)
items = schema.get("items", None)
if items is not None:
strip_titles(items)

View file

@ -2996,12 +2996,16 @@ def get_optional_params(
"vertex_ai_beta",
]
):
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
_build_vertex_schema,
)
old_schema = copy.deepcopy(
non_default_params["response_format"]
.get("json_schema", {})
.get("schema")
)
new_schema = _remove_additional_properties(schema=old_schema)
new_schema = _build_vertex_schema(parameters=new_schema)
non_default_params["response_format"]["json_schema"]["schema"] = new_schema
if "tools" in non_default_params and isinstance(
non_default_params, list

View file

@ -1711,6 +1711,31 @@ def test_completion_perplexity_api():
# test_completion_perplexity_api()
@pytest.mark.skip(
reason="too many requests. Hitting gemini rate limits. Convert to mock test."
)
def test_completion_pydantic_obj_2():
from pydantic import BaseModel
litellm.set_verbose = True
class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]
class EventsList(BaseModel):
events: list[CalendarEvent]
messages = [
{"role": "user", "content": "List important events from the 20th century."}
]
response = litellm.completion(
model="gemini/gemini-1.5-pro", messages=messages, response_format=EventsList
)
@pytest.mark.skip(reason="this test is flaky")
def test_completion_perplexity_api_2():
try: