Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
fix #9783: Retain schema field ordering for google gemini and vertex (#9828)

* test: update test
* refactor(groq.py): initial commit migrating groq to base_llm_http_handler
* fix(streaming_chunk_builder_utils.py): fix how tool content is combined

  Fixes https://github.com/BerriAI/litellm/issues/10034

* fix(vertex_ai/common_utils.py): prevent infinite loop in helper function
* fix(groq/chat/transformation.py): handle groq streaming errors correctly
* fix(groq/chat/transformation.py): handle max_retries

---------

Co-authored-by: Adrian Lyjak <adrian@chatmeter.com>
This commit is contained in:
parent 1b9b745cae
commit fdfa1108a6

12 changed files with 493 additions and 201 deletions
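The headline change (#9783) is that Gemini / Vertex AI structured-output schemas now carry a `propertyOrdering` key derived from the declared field order, so fields are no longer forced into alphabetical order. A minimal sketch of the behavior the new tests further down verify; the schema contents here are illustrative, and `map_openai_params` is called exactly as in those tests:

```python
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    VertexGeminiConfig,
)

transformed = VertexGeminiConfig().map_openai_params(
    non_default_params={
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "scratchpad",  # illustrative schema, not from the diff
                "schema": {
                    "type": "object",
                    "properties": {
                        "thought": {"type": "string"},
                        "output": {"type": "string"},
                    },
                },
            },
        },
    },
    optional_params={},
    model="gemini-2.0-flash-lite",
    drop_params=False,
)
# declared field order is preserved instead of being sorted alphabetically
assert transformed["response_schema"]["propertyOrdering"] == ["thought", "output"]
```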
@@ -106,74 +106,64 @@ class ChunkProcessor:
def get_combined_tool_content(
self, tool_call_chunks: List[Dict[str, Any]]
) -> List[ChatCompletionMessageToolCall]:
argument_list: List[str] = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
id = None
name = None
type = None
tool_calls_list: List[ChatCompletionMessageToolCall] = []
prev_index = None
prev_name = None
prev_id = None
curr_id = None
curr_index = 0
tool_call_map: Dict[
int, Dict[str, Any]
] = {} # Map to store tool calls by index

for chunk in tool_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
tool_calls = delta.get("tool_calls", "")
# Check if a tool call is present
if tool_calls and tool_calls[0].function is not None:
if tool_calls[0].id:
id = tool_calls[0].id
curr_id = id
if prev_id is None:
prev_id = curr_id
if tool_calls[0].index:
curr_index = tool_calls[0].index
if tool_calls[0].function.arguments:
# Now, tool_calls is expected to be a dictionary
arguments = tool_calls[0].function.arguments
argument_list.append(arguments)
if tool_calls[0].function.name:
name = tool_calls[0].function.name
if tool_calls[0].type:
type = tool_calls[0].type
if prev_index is None:
prev_index = curr_index
if prev_name is None:
prev_name = name
if curr_index != prev_index: # new tool call
combined_arguments = "".join(argument_list)
tool_calls = delta.get("tool_calls", [])

for tool_call in tool_calls:
if not tool_call or not hasattr(tool_call, "function"):
continue

index = getattr(tool_call, "index", 0)
if index not in tool_call_map:
tool_call_map[index] = {
"id": None,
"name": None,
"type": None,
"arguments": [],
}

if hasattr(tool_call, "id") and tool_call.id:
tool_call_map[index]["id"] = tool_call.id
if hasattr(tool_call, "type") and tool_call.type:
tool_call_map[index]["type"] = tool_call.type
if hasattr(tool_call, "function"):
if (
hasattr(tool_call.function, "name")
and tool_call.function.name
):
tool_call_map[index]["name"] = tool_call.function.name
if (
hasattr(tool_call.function, "arguments")
and tool_call.function.arguments
):
tool_call_map[index]["arguments"].append(
tool_call.function.arguments
)

# Convert the map to a list of tool calls
for index in sorted(tool_call_map.keys()):
tool_call_data = tool_call_map[index]
if tool_call_data["id"] and tool_call_data["name"]:
combined_arguments = "".join(tool_call_data["arguments"]) or "{}"
tool_calls_list.append(
ChatCompletionMessageToolCall(
id=prev_id,
id=tool_call_data["id"],
function=Function(
arguments=combined_arguments,
name=prev_name,
name=tool_call_data["name"],
),
type=type,
type=tool_call_data["type"] or "function",
index=index,
)
)
argument_list = [] # reset
prev_index = curr_index
prev_id = curr_id
prev_name = name

combined_arguments = (
"".join(argument_list) or "{}"
) # base case, return empty dict

tool_calls_list.append(
ChatCompletionMessageToolCall(
id=id,
type="function",
function=Function(
arguments=combined_arguments,
name=name,
),
)
)

return tool_calls_list
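The rewrite above drops the `prev_*` bookkeeping and instead groups streamed tool-call deltas into `tool_call_map`, keyed by each delta's `index`, which fixes the combined-arguments problem reported in https://github.com/BerriAI/litellm/issues/10034. A hedged sketch of the caller-facing effect, reusing the public `stream_chunk_builder` helper exercised by the tests below (the wrapper function is illustrative):

```python
import json

import litellm


def combined_tool_calls(chunks: list):
    """chunks: streamed ModelResponseStream pieces from a tool-calling completion."""
    assembly = litellm.stream_chunk_builder(chunks)
    tool_calls = assembly.choices[0].message.tool_calls
    for tool_call in tool_calls:
        # each parallel call keeps its own id/name, and its arguments stay valid JSON
        json.loads(tool_call.function.arguments)
    return tool_calls
```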
@@ -230,6 +230,7 @@ class BaseLLMHTTPHandler:
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
json_mode: bool = optional_params.pop("json_mode", False)
extra_body: Optional[dict] = optional_params.pop("extra_body", None)

provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
@@ -267,6 +268,9 @@ class BaseLLMHTTPHandler:
headers=headers,
)

if extra_body is not None:
data = {**data, **extra_body}

headers = provider_config.sign_request(
headers=headers,
optional_params=optional_params,
@@ -57,6 +57,14 @@ class GroqChatConfig(OpenAIGPTConfig):
def get_config(cls):
return super().get_config()

def get_supported_openai_params(self, model: str) -> list:
base_params = super().get_supported_openai_params(model)
try:
base_params.remove("max_retries")
except ValueError:
pass
return base_params

def _transform_messages(self, messages: List[AllMessageValues], model: str) -> List:
for idx, message in enumerate(messages):
"""
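`GroqChatConfig` now removes `max_retries` from the list of provider-supported OpenAI params, presumably so retries are handled on the client side rather than forwarded to Groq. A quick check of the new behavior; the import path follows the `groq/chat/transformation.py` file named in the commit message and the model name is illustrative:

```python
from litellm.llms.groq.chat.transformation import GroqChatConfig

# "max_retries" is stripped from the advertised params by the new override
supported = GroqChatConfig().get_supported_openai_params(model="llama-3.3-70b-versatile")
assert "max_retries" not in supported
```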
@@ -1,5 +1,5 @@
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
import re
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints

import httpx
@@ -165,9 +165,18 @@ def _check_text_in_content(parts: List[PartType]) -> bool:
return has_text_param


def _build_vertex_schema(parameters: dict):
def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False):
"""
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419

Updates the input parameters, removing extraneous fields, adjusting types, unwinding $defs, and adding propertyOrdering if specified, returning the updated parameters.

Parameters:
parameters: dict - the json schema to build from
add_property_ordering: bool - whether to add propertyOrdering to the schema. This is only applicable to schemas for structured outputs. See
set_schema_property_ordering for more details.
Returns:
parameters: dict - the input parameters, modified in place
"""
# Get valid fields from Schema TypedDict
valid_schema_fields = set(get_type_hints(Schema).keys())
@@ -186,8 +195,40 @@ def _build_vertex_schema(parameters: dict):
add_object_type(parameters)
# Postprocessing
# Filter out fields that don't exist in Schema
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
return filtered_parameters
parameters = filter_schema_fields(parameters, valid_schema_fields)

if add_property_ordering:
set_schema_property_ordering(parameters)
return parameters


def set_schema_property_ordering(
schema: Dict[str, Any], depth: int = 0
) -> Dict[str, Any]:
"""
vertex ai and generativeai apis order output of fields alphabetically, unless you specify the order.
python dicts retain order, so we just use that. Note that this field only applies to structured outputs, and not tools.
Function tools are not afflicted by the same alphabetical ordering issue, (the order of keys returned seems to be arbitrary, up to the model)
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.cachedContents#Schema.FIELDS.property_ordering

Args:
schema: The schema dictionary to process
depth: Current recursion depth to prevent infinite loops
"""
if depth > DEFAULT_MAX_RECURSE_DEPTH:
raise ValueError(
f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
)

if "properties" in schema and isinstance(schema["properties"], dict):
# retain propertyOrdering as an escape hatch if user already specifies it
if "propertyOrdering" not in schema:
schema["propertyOrdering"] = [k for k, v in schema["properties"].items()]
for k, v in schema["properties"].items():
set_schema_property_ordering(v, depth + 1)
if "items" in schema:
set_schema_property_ordering(schema["items"], depth + 1)
return schema


def filter_schema_fields(
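Note that `set_schema_property_ordering` only runs when `_build_vertex_schema` is called with `add_property_ordering=True` (structured-output response schemas, not tools), and it deliberately leaves any `propertyOrdering` the caller already set untouched. A small sketch of that escape hatch; the import path matches the test file further below:

```python
from litellm.llms.vertex_ai.common_utils import set_schema_property_ordering

schema = {
    "type": "object",
    "properties": {"output": {"type": "string"}, "thought": {"type": "string"}},
    "propertyOrdering": ["thought", "output"],  # explicit order supplied by the caller
}
set_schema_property_ordering(schema)
# the user-specified ordering wins over dict insertion order
assert schema["propertyOrdering"] == ["thought", "output"]
```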
@@ -207,7 +207,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
"extra_headers",
"seed",
"logprobs",
"top_logprobs", # Added this to list of supported openAI params
"top_logprobs",
"modalities",
]
@@ -313,9 +313,10 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
if isinstance(old_schema, list):
for item in old_schema:
if isinstance(item, dict):
item = _build_vertex_schema(parameters=item)
item = _build_vertex_schema(parameters=item, add_property_ordering=True)

elif isinstance(old_schema, dict):
old_schema = _build_vertex_schema(parameters=old_schema)
old_schema = _build_vertex_schema(parameters=old_schema, add_property_ordering=True)
return old_schema

def apply_response_schema_transformation(self, value: dict, optional_params: dict):
@@ -1622,24 +1622,22 @@ def completion( # type: ignore # noqa: PLR0915
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v

response = groq_chat_completions.completion(
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
custom_llm_provider=custom_llm_provider,
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
client=client,
)
elif custom_llm_provider == "aiohttp_openai":
# NEW aiohttp provider for 10-100x higher RPS
@@ -2658,9 +2656,9 @@ def completion( # type: ignore # noqa: PLR0915
"aws_region_name" not in optional_params
or optional_params["aws_region_name"] is None
):
optional_params["aws_region_name"] = (
aws_bedrock_client.meta.region_name
)
optional_params[
"aws_region_name"
] = aws_bedrock_client.meta.region_name

bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse":
@@ -4367,9 +4365,9 @@ def adapter_completion(
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
None
)
translated_response: Optional[
Union[BaseModel, AdapterCompletionStreamWrapper]
] = None
if isinstance(response, ModelResponse):
translated_response = translation_obj.translate_completion_output_params(
response=response
@@ -5789,9 +5787,9 @@ def stream_chunk_builder( # noqa: PLR0915
]

if len(content_chunks) > 0:
response["choices"][0]["message"]["content"] = (
processor.get_combined_content(content_chunks)
)
response["choices"][0]["message"][
"content"
] = processor.get_combined_content(content_chunks)

reasoning_chunks = [
chunk
@@ -5802,9 +5800,9 @@ def stream_chunk_builder( # noqa: PLR0915
]

if len(reasoning_chunks) > 0:
response["choices"][0]["message"]["reasoning_content"] = (
processor.get_combined_reasoning_content(reasoning_chunks)
)
response["choices"][0]["message"][
"reasoning_content"
] = processor.get_combined_reasoning_content(reasoning_chunks)

audio_chunks = [
chunk
@@ -18,6 +18,7 @@ IGNORE_FUNCTIONS = [
"_serialize", # we now set a max depth for this
"_sanitize_request_body_for_spend_logs_payload", # testing added for circular reference
"_sanitize_value", # testing added for circular reference
"set_schema_property_ordering", # testing added for infinite recursion
]
@@ -0,0 +1,158 @@
import json
import os
import sys

import pytest

sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path

from litellm.litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
from litellm.types.utils import (
ChatCompletionDeltaToolCall,
ChatCompletionMessageToolCall,
Delta,
Function,
ModelResponseStream,
StreamingChoices,
)


def test_get_combined_tool_content():
chunks = [
ModelResponseStream(
id="chatcmpl-8478099a-3724-42c7-9194-88d97ffd254b",
created=1744771912,
model="llama-3.3-70b-versatile",
object="chat.completion.chunk",
system_fingerprint=None,
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(
provider_specific_fields=None,
content=None,
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionDeltaToolCall(
id="call_m87w",
function=Function(
arguments='{"location": "San Francisco", "unit": "imperial"}',
name="get_current_weather",
),
type="function",
index=0,
)
],
audio=None,
),
logprobs=None,
)
],
provider_specific_fields=None,
stream_options=None,
),
ModelResponseStream(
id="chatcmpl-8478099a-3724-42c7-9194-88d97ffd254b",
created=1744771912,
model="llama-3.3-70b-versatile",
object="chat.completion.chunk",
system_fingerprint=None,
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(
provider_specific_fields=None,
content=None,
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionDeltaToolCall(
id="call_rrns",
function=Function(
arguments='{"location": "Tokyo", "unit": "metric"}',
name="get_current_weather",
),
type="function",
index=1,
)
],
audio=None,
),
logprobs=None,
)
],
provider_specific_fields=None,
stream_options=None,
),
ModelResponseStream(
id="chatcmpl-8478099a-3724-42c7-9194-88d97ffd254b",
created=1744771912,
model="llama-3.3-70b-versatile",
object="chat.completion.chunk",
system_fingerprint=None,
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(
provider_specific_fields=None,
content=None,
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionDeltaToolCall(
id="call_0k29",
function=Function(
arguments='{"location": "Paris", "unit": "metric"}',
name="get_current_weather",
),
type="function",
index=2,
)
],
audio=None,
),
logprobs=None,
)
],
provider_specific_fields=None,
stream_options=None,
),
]
chunk_processor = ChunkProcessor(chunks=chunks)

tool_calls_list = chunk_processor.get_combined_tool_content(chunks)
assert tool_calls_list == [
ChatCompletionMessageToolCall(
id="call_m87w",
function=Function(
arguments='{"location": "San Francisco", "unit": "imperial"}',
name="get_current_weather",
),
type="function",
index=0,
),
ChatCompletionMessageToolCall(
id="call_rrns",
function=Function(
arguments='{"location": "Tokyo", "unit": "metric"}',
name="get_current_weather",
),
type="function",
index=1,
),
ChatCompletionMessageToolCall(
id="call_0k29",
function=Function(
arguments='{"location": "Paris", "unit": "metric"}',
name="get_current_weather",
),
type="function",
index=2,
),
]
@@ -9,6 +9,8 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.utils import ChoiceLogprobs
from pydantic import BaseModel
from typing import List, cast


def test_top_logprobs():
@@ -62,3 +64,160 @@ def test_get_model_name_from_gemini_spec_model():
model = "gemini/ft-uuid-123"
result = VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
assert result == "ft-uuid-123"



def test_vertex_ai_response_schema_dict():
v = VertexGeminiConfig()
transformed_request = v.map_openai_params(
non_default_params={
"messages": [{"role": "user", "content": "Hello, world!"}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "math_reasoning",
"schema": {
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"thought": {"type": "string"},
"output": {"type": "string"},
},
"required": ["thought", "output"],
"additionalProperties": False,
},
},
"final_answer": {"type": "string"},
},
"required": ["steps", "final_answer"],
"additionalProperties": False,
},
"strict": False,
},
},
},
optional_params={},
model="gemini-2.0-flash-lite",
drop_params=False,
)

schema = transformed_request["response_schema"]
# should add propertyOrdering
assert schema["propertyOrdering"] == ["steps", "final_answer"]
# should add propertyOrdering (recursively, including array items)
assert schema["properties"]["steps"]["items"]["propertyOrdering"] == [
"thought",
"output",
]
# should strip strict and additionalProperties
assert "strict" not in schema
assert "additionalProperties" not in schema
# validate the whole thing to catch regressions
assert transformed_request["response_schema"] == {
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"thought": {"type": "string"},
"output": {"type": "string"},
},
"required": ["thought", "output"],
"propertyOrdering": ["thought", "output"],
},
},
"final_answer": {"type": "string"},
},
"required": ["steps", "final_answer"],
"propertyOrdering": ["steps", "final_answer"],
}


class MathReasoning(BaseModel):
steps: List["Step"]
final_answer: str


class Step(BaseModel):
thought: str
output: str


def test_vertex_ai_response_schema_defs():
v = VertexGeminiConfig()

schema = cast(dict, v.get_json_schema_from_pydantic_object(MathReasoning))

# pydantic conversion by default adds $defs to the schema, make sure this is still the case, otherwise this test isn't really testing anything
assert "$defs" in schema["json_schema"]["schema"]

transformed_request = v.map_openai_params(
non_default_params={
"messages": [{"role": "user", "content": "Hello, world!"}],
"response_format": schema,
},
optional_params={},
model="gemini-2.0-flash-lite",
drop_params=False,
)

assert "$defs" not in transformed_request["response_schema"]
assert transformed_request["response_schema"] == {
"title": "MathReasoning",
"type": "object",
"properties": {
"steps": {
"title": "Steps",
"type": "array",
"items": {
"title": "Step",
"type": "object",
"properties": {
"thought": {"title": "Thought", "type": "string"},
"output": {"title": "Output", "type": "string"},
},
"required": ["thought", "output"],
"propertyOrdering": ["thought", "output"],
},
},
"final_answer": {"title": "Final Answer", "type": "string"},
},
"required": ["steps", "final_answer"],
"propertyOrdering": ["steps", "final_answer"],
}


def test_vertex_ai_retain_property_ordering():
v = VertexGeminiConfig()
transformed_request = v.map_openai_params(
non_default_params={
"messages": [{"role": "user", "content": "Hello, world!"}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "math_reasoning",
"schema": {
"type": "object",
"properties": {
"output": {"type": "string"},
"thought": {"type": "string"},
},
"propertyOrdering": ["thought", "output"],
},
},
},
},
optional_params={},
model="gemini-2.0-flash-lite",
drop_params=False,
)

schema = transformed_request["response_schema"]
# should leave existing value alone, despite dictionary ordering
assert schema["propertyOrdering"] == ["thought", "output"]
@@ -1,19 +1,22 @@
import os
import sys
from typing import Any, Dict
from unittest.mock import MagicMock, call, patch
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH

import pytest

from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH

sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path

import litellm
from litellm.llms.vertex_ai.common_utils import (
convert_anyof_null_to_nullable,
get_vertex_location_from_url,
get_vertex_project_id_from_url,
convert_anyof_null_to_nullable
set_schema_property_ordering,
)
@@ -44,31 +47,19 @@ async def test_get_vertex_location_from_url():
location = get_vertex_location_from_url(url)
assert location is None


def test_basic_anyof_conversion():
"""Test basic conversion of anyOf with 'null'."""
schema = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string"},
{"type": "null"}
]
}
}
"properties": {"example": {"anyOf": [{"type": "string"}, {"type": "null"}]}},
}

convert_anyof_null_to_nullable(schema)

expected = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string", "nullable": True}
]
}
}
"properties": {"example": {"anyOf": [{"type": "string", "nullable": True}]}},
}
assert schema == expected
@@ -85,12 +76,12 @@ def test_nested_anyof_conversion():
"anyOf": [
{"type": "array", "items": {"type": "string"}},
{"type": "string"},
{"type": "null"}
{"type": "null"},
]
}
}
},
}
}
},
}

convert_anyof_null_to_nullable(schema)
@@ -103,16 +94,21 @@ def test_nested_anyof_conversion():
"properties": {
"inner": {
"anyOf": [
{"type": "array", "items": {"type": "string"}, "nullable": True},
{"type": "string", "nullable": True}
{
"type": "array",
"items": {"type": "string"},
"nullable": True,
},
{"type": "string", "nullable": True},
]
}
}
},
}
}
},
}
assert schema == expected


def test_anyof_with_excessive_nesting():
"""Test conversion with excessive nesting > max levels +1 deep."""
# generate a schema with excessive nesting
@@ -121,22 +117,20 @@ def test_anyof_with_excessive_nesting():
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
current["properties"] = {
"nested": {
"anyOf": [
{"type": "string"},
{"type": "null"}
],
"properties": {}
"anyOf": [{"type": "string"}, {"type": "null"}],
"properties": {},
}
}
current = current["properties"]["nested"]


# running the conversion will raise an error
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
with pytest.raises(
ValueError,
match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting.",
):
convert_anyof_null_to_nullable(schema)



@pytest.mark.asyncio
async def test_get_supports_system_message():
"""Test get_supports_system_message with different models"""
@@ -153,93 +147,20 @@ async def test_get_supports_system_message():
model="random-model-name", custom_llm_provider="vertex_ai"
)
assert result == False
def test_basic_anyof_conversion():
"""Test basic conversion of anyOf with 'null'."""
schema = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string"},
{"type": "null"}
]
}
}
}

convert_anyof_null_to_nullable(schema)

expected = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string", "nullable": True}
]
}
}
}
assert schema == expected


def test_nested_anyof_conversion():
"""Test nested conversion with 'anyOf' inside properties."""
schema = {
"type": "object",
"properties": {
"outer": {
"type": "object",
"properties": {
"inner": {
"anyOf": [
{"type": "array", "items": {"type": "string"}},
{"type": "string"},
{"type": "null"}
]
}
}
}
}
}

convert_anyof_null_to_nullable(schema)

expected = {
"type": "object",
"properties": {
"outer": {
"type": "object",
"properties": {
"inner": {
"anyOf": [
{"type": "array", "items": {"type": "string"}, "nullable": True},
{"type": "string", "nullable": True}
]
}
}
}
}
}
assert schema == expected

def test_anyof_with_excessive_nesting():
"""Test conversion with excessive nesting > max levels +1 deep."""
def test_set_schema_property_ordering_with_excessive_nesting():
"""Test set_schema_property_ordering with excessive nesting > max levels +1 deep."""
# generate a schema with excessive nesting
schema = {"type": "object", "properties": {}}
current = schema
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
current["properties"] = {
"nested": {
"anyOf": [
{"type": "string"},
{"type": "null"}
],
"properties": {}
}
}
current["properties"] = {"nested": {"type": "object", "properties": {}}}
current = current["properties"]["nested"]


# running the conversion will raise an error
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
convert_anyof_null_to_nullable(schema)
# running the function will raise an error
with pytest.raises(
ValueError,
match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting.",
):
set_schema_property_ordering(schema)
@@ -71,6 +71,11 @@ def test_completion_pydantic_obj_2():
"type": "array",
},
},
"propertyOrdering": [
"name",
"date",
"participants",
],
"required": ["name", "date", "participants"],
"title": "CalendarEvent",
"type": "object",
@@ -79,6 +84,7 @@ def test_completion_pydantic_obj_2():
"type": "array",
}
},
"propertyOrdering": ["events"],
"required": ["events"],
"title": "EventsList",
"type": "object",
@@ -926,12 +926,17 @@ def execute_completion(opts: dict):
response_gen = litellm.completion(**opts)
for i, part in enumerate(response_gen):
partial_streaming_chunks.append(part)
print("\n\n")
print(f"partial_streaming_chunks: {partial_streaming_chunks}")
print("\n\n")
assembly = litellm.stream_chunk_builder(partial_streaming_chunks)
print(assembly.choices[0].message.tool_calls)
print(f"assembly.choices[0].message.tool_calls: {assembly.choices[0].message.tool_calls}")
assert len(assembly.choices[0].message.tool_calls) == 3, (
assembly.choices[0].message.tool_calls[0].function.arguments[0]
)
print(assembly.choices[0].message.tool_calls)
for tool_call in assembly.choices[0].message.tool_calls:
json.loads(tool_call.function.arguments) # assert valid json - https://github.com/BerriAI/litellm/issues/10034


def test_grok_bug(load_env):