forked from phoenix/litellm-mirror
fix(utils.py): fix fix pydantic obj to schema creation for vertex en… (#6071)
* fix(utils.py): fix fix pydantic obj to schema creation for vertex endpoints Fixes https://github.com/BerriAI/litellm/issues/6027 * test(test_completion.pyu): skip test - avoid hitting gemini rate limits * fix(common_utils.py): fix ruff linting error
This commit is contained in:
parent
29da2d49d6
commit
49d8b2be46
3 changed files with 143 additions and 1 deletions
|
@ -91,7 +91,7 @@ def _get_vertex_url(
|
|||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||
|
||||
if not url or not endpoint:
|
||||
raise ValueError(f"Unable to get vertex url/endpoinit for mode: {mode}")
|
||||
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
|
||||
return url, endpoint
|
||||
|
||||
|
||||
|
@ -142,3 +142,116 @@ def _check_text_in_content(parts: List[PartType]) -> bool:
|
|||
has_text_param = True
|
||||
|
||||
return has_text_param
|
||||
|
||||
|
||||
def _build_vertex_schema(parameters: dict):
|
||||
"""
|
||||
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
|
||||
"""
|
||||
defs = parameters.pop("$defs", {})
|
||||
# flatten the defs
|
||||
for name, value in defs.items():
|
||||
unpack_defs(value, defs)
|
||||
unpack_defs(parameters, defs)
|
||||
|
||||
# 5. Nullable fields:
|
||||
# * https://github.com/pydantic/pydantic/issues/1270
|
||||
# * https://stackoverflow.com/a/58841311
|
||||
# * https://github.com/pydantic/pydantic/discussions/4872
|
||||
convert_to_nullable(parameters)
|
||||
add_object_type(parameters)
|
||||
# Postprocessing
|
||||
# 4. Suppress unnecessary title generation:
|
||||
# * https://github.com/pydantic/pydantic/issues/1051
|
||||
# * http://cl/586221780
|
||||
strip_titles(parameters)
|
||||
return parameters
|
||||
|
||||
|
||||
def unpack_defs(schema, defs):
|
||||
properties = schema.get("properties", None)
|
||||
if properties is None:
|
||||
return
|
||||
|
||||
for name, value in properties.items():
|
||||
ref_key = value.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
properties[name] = ref
|
||||
continue
|
||||
|
||||
anyof = value.get("anyOf", None)
|
||||
if anyof is not None:
|
||||
for i, atype in enumerate(anyof):
|
||||
ref_key = atype.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
anyof[i] = ref
|
||||
continue
|
||||
|
||||
items = value.get("items", None)
|
||||
if items is not None:
|
||||
ref_key = items.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
value["items"] = ref
|
||||
continue
|
||||
|
||||
|
||||
def convert_to_nullable(schema):
|
||||
anyof = schema.pop("anyOf", None)
|
||||
if anyof is not None:
|
||||
if len(anyof) != 2:
|
||||
raise ValueError(
|
||||
"Invalid input: Type Unions are not supported, except for `Optional` types. "
|
||||
"Please provide an `Optional` type or a non-Union type."
|
||||
)
|
||||
a, b = anyof
|
||||
if a == {"type": "null"}:
|
||||
schema.update(b)
|
||||
elif b == {"type": "null"}:
|
||||
schema.update(a)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid input: Type Unions are not supported, except for `Optional` types. "
|
||||
"Please provide an `Optional` type or a non-Union type."
|
||||
)
|
||||
schema["nullable"] = True
|
||||
|
||||
properties = schema.get("properties", None)
|
||||
if properties is not None:
|
||||
for name, value in properties.items():
|
||||
convert_to_nullable(value)
|
||||
|
||||
items = schema.get("items", None)
|
||||
if items is not None:
|
||||
convert_to_nullable(items)
|
||||
|
||||
|
||||
def add_object_type(schema):
|
||||
properties = schema.get("properties", None)
|
||||
if properties is not None:
|
||||
schema.pop("required", None)
|
||||
schema["type"] = "object"
|
||||
for name, value in properties.items():
|
||||
add_object_type(value)
|
||||
|
||||
items = schema.get("items", None)
|
||||
if items is not None:
|
||||
add_object_type(items)
|
||||
|
||||
|
||||
def strip_titles(schema):
|
||||
schema.pop("title", None)
|
||||
|
||||
properties = schema.get("properties", None)
|
||||
if properties is not None:
|
||||
for name, value in properties.items():
|
||||
strip_titles(value)
|
||||
|
||||
items = schema.get("items", None)
|
||||
if items is not None:
|
||||
strip_titles(items)
|
||||
|
|
|
@ -2996,12 +2996,16 @@ def get_optional_params(
|
|||
"vertex_ai_beta",
|
||||
]
|
||||
):
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
|
||||
_build_vertex_schema,
|
||||
)
|
||||
old_schema = copy.deepcopy(
|
||||
non_default_params["response_format"]
|
||||
.get("json_schema", {})
|
||||
.get("schema")
|
||||
)
|
||||
new_schema = _remove_additional_properties(schema=old_schema)
|
||||
new_schema = _build_vertex_schema(parameters=new_schema)
|
||||
non_default_params["response_format"]["json_schema"]["schema"] = new_schema
|
||||
if "tools" in non_default_params and isinstance(
|
||||
non_default_params, list
|
||||
|
|
|
@ -1711,6 +1711,31 @@ def test_completion_perplexity_api():
|
|||
# test_completion_perplexity_api()
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="too many requests. Hitting gemini rate limits. Convert to mock test."
|
||||
)
|
||||
def test_completion_pydantic_obj_2():
|
||||
from pydantic import BaseModel
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
participants: list[str]
|
||||
|
||||
class EventsList(BaseModel):
|
||||
events: list[CalendarEvent]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "List important events from the 20th century."}
|
||||
]
|
||||
|
||||
response = litellm.completion(
|
||||
model="gemini/gemini-1.5-pro", messages=messages, response_format=EventsList
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="this test is flaky")
|
||||
def test_completion_perplexity_api_2():
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue