Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00.
fix #9783: Retain schema field ordering for google gemini and vertex (#9828)

* test: update test
* refactor(groq.py): initial commit migrating groq to base_llm_http_handler
* fix(streaming_chunk_builder_utils.py): fix how tool content is combined. Fixes https://github.com/BerriAI/litellm/issues/10034
* fix(vertex_ai/common_utils.py): prevent infinite loop in helper function
* fix(groq/chat/transformation.py): handle groq streaming errors correctly
* fix(groq/chat/transformation.py): handle max_retries

Co-authored-by: Adrian Lyjak <adrian@chatmeter.com>
This commit is contained in: parent 1b9b745cae, commit fdfa1108a6.
12 changed files with 493 additions and 201 deletions
@@ -106,74 +106,64 @@ class ChunkProcessor:
     def get_combined_tool_content(
         self, tool_call_chunks: List[Dict[str, Any]]
     ) -> List[ChatCompletionMessageToolCall]:
-        argument_list: List[str] = []
-        delta = tool_call_chunks[0]["choices"][0]["delta"]
-        id = None
-        name = None
-        type = None
         tool_calls_list: List[ChatCompletionMessageToolCall] = []
-        prev_index = None
-        prev_name = None
-        prev_id = None
-        curr_id = None
-        curr_index = 0
+        tool_call_map: Dict[
+            int, Dict[str, Any]
+        ] = {}  # Map to store tool calls by index
+
         for chunk in tool_call_chunks:
             choices = chunk["choices"]
             for choice in choices:
                 delta = choice.get("delta", {})
-                tool_calls = delta.get("tool_calls", "")
-                # Check if a tool call is present
-                if tool_calls and tool_calls[0].function is not None:
-                    if tool_calls[0].id:
-                        id = tool_calls[0].id
-                        curr_id = id
-                        if prev_id is None:
-                            prev_id = curr_id
-                    if tool_calls[0].index:
-                        curr_index = tool_calls[0].index
-                    if tool_calls[0].function.arguments:
-                        # Now, tool_calls is expected to be a dictionary
-                        arguments = tool_calls[0].function.arguments
-                        argument_list.append(arguments)
-                    if tool_calls[0].function.name:
-                        name = tool_calls[0].function.name
-                    if tool_calls[0].type:
-                        type = tool_calls[0].type
-                if prev_index is None:
-                    prev_index = curr_index
-                if prev_name is None:
-                    prev_name = name
-                if curr_index != prev_index:  # new tool call
-                    combined_arguments = "".join(argument_list)
-                    tool_calls_list.append(
-                        ChatCompletionMessageToolCall(
-                            id=prev_id,
-                            function=Function(
-                                arguments=combined_arguments,
-                                name=prev_name,
-                            ),
-                            type=type,
-                        )
-                    )
-                    argument_list = []  # reset
-                    prev_index = curr_index
-                    prev_id = curr_id
-                    prev_name = name
-
-        combined_arguments = (
-            "".join(argument_list) or "{}"
-        )  # base case, return empty dict
-        tool_calls_list.append(
-            ChatCompletionMessageToolCall(
-                id=id,
-                type="function",
-                function=Function(
-                    arguments=combined_arguments,
-                    name=name,
-                ),
-            )
-        )
+                tool_calls = delta.get("tool_calls", [])
+
+                for tool_call in tool_calls:
+                    if not tool_call or not hasattr(tool_call, "function"):
+                        continue
+
+                    index = getattr(tool_call, "index", 0)
+                    if index not in tool_call_map:
+                        tool_call_map[index] = {
+                            "id": None,
+                            "name": None,
+                            "type": None,
+                            "arguments": [],
+                        }
+
+                    if hasattr(tool_call, "id") and tool_call.id:
+                        tool_call_map[index]["id"] = tool_call.id
+                    if hasattr(tool_call, "type") and tool_call.type:
+                        tool_call_map[index]["type"] = tool_call.type
+                    if hasattr(tool_call, "function"):
+                        if (
+                            hasattr(tool_call.function, "name")
+                            and tool_call.function.name
+                        ):
+                            tool_call_map[index]["name"] = tool_call.function.name
+                        if (
+                            hasattr(tool_call.function, "arguments")
+                            and tool_call.function.arguments
+                        ):
+                            tool_call_map[index]["arguments"].append(
+                                tool_call.function.arguments
+                            )
+
+        # Convert the map to a list of tool calls
+        for index in sorted(tool_call_map.keys()):
+            tool_call_data = tool_call_map[index]
+            if tool_call_data["id"] and tool_call_data["name"]:
+                combined_arguments = "".join(tool_call_data["arguments"]) or "{}"
+                tool_calls_list.append(
+                    ChatCompletionMessageToolCall(
+                        id=tool_call_data["id"],
+                        function=Function(
+                            arguments=combined_arguments,
+                            name=tool_call_data["name"],
+                        ),
+                        type=tool_call_data["type"] or "function",
+                        index=index,
+                    )
+                )
+
         return tool_calls_list
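The rewrite above stops the argument fragments of parallel tool calls from being interleaved: deltas are grouped by their `index` field and each group's argument strings are concatenated separately. A minimal standalone sketch of that grouping strategy, using plain dicts and hypothetical deltas rather than litellm's delta objects:

from collections import defaultdict

# hypothetical streamed deltas for two parallel tool calls
deltas = [
    {"index": 0, "id": "call_a", "name": "get_weather", "arguments": '{"city": '},
    {"index": 0, "id": None, "name": None, "arguments": '"Tokyo"}'},
    {"index": 1, "id": "call_b", "name": "get_time", "arguments": '{"tz": "UTC"}'},
]

calls = defaultdict(lambda: {"id": None, "name": None, "arguments": []})
for d in deltas:
    slot = calls[d["index"]]
    slot["id"] = slot["id"] or d["id"]        # first non-empty id wins
    slot["name"] = slot["name"] or d["name"]  # first non-empty name wins
    if d["arguments"]:
        slot["arguments"].append(d["arguments"])  # fragments stay within their index

combined = [
    {"id": c["id"], "name": c["name"], "arguments": "".join(c["arguments"]) or "{}"}
    for _, c in sorted(calls.items(), key=lambda kv: kv[0])
]
assert combined[0]["arguments"] == '{"city": "Tokyo"}'
assert combined[1]["arguments"] == '{"tz": "UTC"}'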
@@ -230,6 +230,7 @@ class BaseLLMHTTPHandler:
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
+        extra_body: Optional[dict] = optional_params.pop("extra_body", None)

         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)

@@ -267,6 +268,9 @@ class BaseLLMHTTPHandler:
             headers=headers,
         )

+        if extra_body is not None:
+            data = {**data, **extra_body}
+
         headers = provider_config.sign_request(
             headers=headers,
             optional_params=optional_params,
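With this change, any `extra_body` the caller supplies is merged into the outgoing JSON payload after the provider transformation, for providers routed through `BaseLLMHTTPHandler`. A hedged usage sketch (the extra field below is illustrative, not a documented provider parameter):

import litellm

# extra_body keys are merged into the request body sent to the provider
response = litellm.completion(
    model="groq/llama-3.3-70b-versatile",
    messages=[{"role": "user", "content": "hi"}],
    extra_body={"custom_field": "value"},  # hypothetical provider-specific field
)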
@@ -57,6 +57,14 @@ class GroqChatConfig(OpenAIGPTConfig):
     def get_config(cls):
         return super().get_config()

+    def get_supported_openai_params(self, model: str) -> list:
+        base_params = super().get_supported_openai_params(model)
+        try:
+            base_params.remove("max_retries")
+        except ValueError:
+            pass
+        return base_params
+
     def _transform_messages(self, messages: List[AllMessageValues], model: str) -> List:
         for idx, message in enumerate(messages):
             """
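Dropping `max_retries` from the supported-parameter list means litellm no longer forwards it to Groq as part of the request; retries stay client-side. A small sketch of the resulting behavior, assuming `GroqChatConfig` can be instantiated directly from `litellm.llms.groq.chat.transformation`:

from litellm.llms.groq.chat.transformation import GroqChatConfig

params = GroqChatConfig().get_supported_openai_params(model="llama-3.3-70b-versatile")
assert "max_retries" not in params  # stripped, so it is never sent to the Groq API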
@@ -1,5 +1,5 @@
-from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
 import re
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints

 import httpx
@@ -165,9 +165,18 @@ def _check_text_in_content(parts: List[PartType]) -> bool:
     return has_text_param


-def _build_vertex_schema(parameters: dict):
+def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False):
     """
     This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
+
+    Updates the input parameters, removing extraneous fields, adjusting types, unwinding $defs, and adding propertyOrdering if specified, returning the updated parameters.
+
+    Parameters:
+        parameters: dict - the json schema to build from
+        add_property_ordering: bool - whether to add propertyOrdering to the schema. This is only applicable to schemas for structured outputs. See
+        set_schema_property_ordering for more details.
+    Returns:
+        parameters: dict - the input parameters, modified in place
     """
     # Get valid fields from Schema TypedDict
     valid_schema_fields = set(get_type_hints(Schema).keys())
@@ -186,8 +195,40 @@ def _build_vertex_schema(parameters: dict):
     add_object_type(parameters)
     # Postprocessing
     # Filter out fields that don't exist in Schema
-    filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
-    return filtered_parameters
+    parameters = filter_schema_fields(parameters, valid_schema_fields)
+
+    if add_property_ordering:
+        set_schema_property_ordering(parameters)
+    return parameters
+
+
+def set_schema_property_ordering(
+    schema: Dict[str, Any], depth: int = 0
+) -> Dict[str, Any]:
+    """
+    vertex ai and generativeai apis order output of fields alphabetically, unless you specify the order.
+    python dicts retain order, so we just use that. Note that this field only applies to structured outputs, and not tools.
+    Function tools are not afflicted by the same alphabetical ordering issue, (the order of keys returned seems to be arbitrary, up to the model)
+    https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.cachedContents#Schema.FIELDS.property_ordering
+
+    Args:
+        schema: The schema dictionary to process
+        depth: Current recursion depth to prevent infinite loops
+    """
+    if depth > DEFAULT_MAX_RECURSE_DEPTH:
+        raise ValueError(
+            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
+        )
+
+    if "properties" in schema and isinstance(schema["properties"], dict):
+        # retain propertyOrdering as an escape hatch if user already specifies it
+        if "propertyOrdering" not in schema:
+            schema["propertyOrdering"] = [k for k, v in schema["properties"].items()]
+        for k, v in schema["properties"].items():
+            set_schema_property_ordering(v, depth + 1)
+    if "items" in schema:
+        set_schema_property_ordering(schema["items"], depth + 1)
+    return schema
+
+
 def filter_schema_fields(
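Because Python dicts preserve insertion order, the helper can read the declared field order straight off the schema's `properties`. A minimal sketch of its effect (import path taken from the test changes further down; the schema itself is illustrative):

from litellm.llms.vertex_ai.common_utils import set_schema_property_ordering

schema = {
    "type": "object",
    "properties": {
        "steps": {
            "type": "array",
            "items": {"type": "object", "properties": {"thought": {"type": "string"}}},
        },
        "final_answer": {"type": "string"},
    },
}
set_schema_property_ordering(schema)

# top-level and nested objects (including array items) get a propertyOrdering
assert schema["propertyOrdering"] == ["steps", "final_answer"]
assert schema["properties"]["steps"]["items"]["propertyOrdering"] == ["thought"]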
@@ -207,7 +207,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "extra_headers",
             "seed",
             "logprobs",
-            "top_logprobs",  # Added this to list of supported openAI params
+            "top_logprobs",
             "modalities",
         ]
@@ -313,9 +313,10 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         if isinstance(old_schema, list):
             for item in old_schema:
                 if isinstance(item, dict):
-                    item = _build_vertex_schema(parameters=item)
+                    item = _build_vertex_schema(parameters=item, add_property_ordering=True)
+
         elif isinstance(old_schema, dict):
-            old_schema = _build_vertex_schema(parameters=old_schema)
+            old_schema = _build_vertex_schema(parameters=old_schema, add_property_ordering=True)
         return old_schema

     def apply_response_schema_transformation(self, value: dict, optional_params: dict):
@@ -1622,24 +1622,22 @@ def completion(  # type: ignore  # noqa: PLR0915
         ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

-        response = groq_chat_completions.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
-            headers=headers,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            api_key=api_key,
-            api_base=api_base,
             acompletion=acompletion,
-            logging_obj=logging,
+            api_base=api_base,
+            model_response=model_response,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            timeout=timeout,  # type: ignore
-            custom_prompt_dict=custom_prompt_dict,
-            client=client,  # pass AsyncOpenAI, OpenAI client
             custom_llm_provider=custom_llm_provider,
+            timeout=timeout,
+            headers=headers,
             encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+            client=client,
         )
     elif custom_llm_provider == "aiohttp_openai":
         # NEW aiohttp provider for 10-100x higher RPS
@@ -2658,9 +2656,9 @@ def completion(  # type: ignore  # noqa: PLR0915
                     "aws_region_name" not in optional_params
                     or optional_params["aws_region_name"] is None
                 ):
-                    optional_params["aws_region_name"] = (
-                        aws_bedrock_client.meta.region_name
-                    )
+                    optional_params[
+                        "aws_region_name"
+                    ] = aws_bedrock_client.meta.region_name

             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             if bedrock_route == "converse":
@@ -4367,9 +4365,9 @@ def adapter_completion(
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

     response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
-        None
-    )
+    translated_response: Optional[
+        Union[BaseModel, AdapterCompletionStreamWrapper]
+    ] = None
     if isinstance(response, ModelResponse):
         translated_response = translation_obj.translate_completion_output_params(
             response=response
@@ -5789,9 +5787,9 @@ def stream_chunk_builder(  # noqa: PLR0915
     ]

     if len(content_chunks) > 0:
-        response["choices"][0]["message"]["content"] = (
-            processor.get_combined_content(content_chunks)
-        )
+        response["choices"][0]["message"][
+            "content"
+        ] = processor.get_combined_content(content_chunks)

     reasoning_chunks = [
         chunk
@@ -5802,9 +5800,9 @@ def stream_chunk_builder(  # noqa: PLR0915
     ]

     if len(reasoning_chunks) > 0:
-        response["choices"][0]["message"]["reasoning_content"] = (
-            processor.get_combined_reasoning_content(reasoning_chunks)
-        )
+        response["choices"][0]["message"][
+            "reasoning_content"
+        ] = processor.get_combined_reasoning_content(reasoning_chunks)

     audio_chunks = [
         chunk
@@ -18,6 +18,7 @@ IGNORE_FUNCTIONS = [
     "_serialize",  # we now set a max depth for this
     "_sanitize_request_body_for_spend_logs_payload",  # testing added for circular reference
     "_sanitize_value",  # testing added for circular reference
+    "set_schema_property_ordering",  # testing added for infinite recursion
 ]
@@ -0,0 +1,158 @@
+import json
+import os
+import sys
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
+from litellm.types.utils import (
+    ChatCompletionDeltaToolCall,
+    ChatCompletionMessageToolCall,
+    Delta,
+    Function,
+    ModelResponseStream,
+    StreamingChoices,
+)
+
+
+def test_get_combined_tool_content():
+    chunks = [
+        ModelResponseStream(
+            id="chatcmpl-8478099a-3724-42c7-9194-88d97ffd254b",
+            created=1744771912,
+            model="llama-3.3-70b-versatile",
+            object="chat.completion.chunk",
+            system_fingerprint=None,
+            choices=[
+                StreamingChoices(
+                    finish_reason=None,
+                    index=0,
+                    delta=Delta(
+                        provider_specific_fields=None,
+                        content=None,
+                        role="assistant",
+                        function_call=None,
+                        tool_calls=[
+                            ChatCompletionDeltaToolCall(
+                                id="call_m87w",
+                                function=Function(
+                                    arguments='{"location": "San Francisco", "unit": "imperial"}',
+                                    name="get_current_weather",
+                                ),
+                                type="function",
+                                index=0,
+                            )
+                        ],
+                        audio=None,
+                    ),
+                    logprobs=None,
+                )
+            ],
+            provider_specific_fields=None,
+            stream_options=None,
+        ),
+        ModelResponseStream(
+            id="chatcmpl-8478099a-3724-42c7-9194-88d97ffd254b",
+            created=1744771912,
+            model="llama-3.3-70b-versatile",
+            object="chat.completion.chunk",
+            system_fingerprint=None,
+            choices=[
+                StreamingChoices(
+                    finish_reason=None,
+                    index=0,
+                    delta=Delta(
+                        provider_specific_fields=None,
+                        content=None,
+                        role="assistant",
+                        function_call=None,
+                        tool_calls=[
+                            ChatCompletionDeltaToolCall(
+                                id="call_rrns",
+                                function=Function(
+                                    arguments='{"location": "Tokyo", "unit": "metric"}',
+                                    name="get_current_weather",
+                                ),
+                                type="function",
+                                index=1,
+                            )
+                        ],
+                        audio=None,
+                    ),
+                    logprobs=None,
+                )
+            ],
+            provider_specific_fields=None,
+            stream_options=None,
+        ),
+        ModelResponseStream(
+            id="chatcmpl-8478099a-3724-42c7-9194-88d97ffd254b",
+            created=1744771912,
+            model="llama-3.3-70b-versatile",
+            object="chat.completion.chunk",
+            system_fingerprint=None,
+            choices=[
+                StreamingChoices(
+                    finish_reason=None,
+                    index=0,
+                    delta=Delta(
+                        provider_specific_fields=None,
+                        content=None,
+                        role="assistant",
+                        function_call=None,
+                        tool_calls=[
+                            ChatCompletionDeltaToolCall(
+                                id="call_0k29",
+                                function=Function(
+                                    arguments='{"location": "Paris", "unit": "metric"}',
+                                    name="get_current_weather",
+                                ),
+                                type="function",
+                                index=2,
+                            )
+                        ],
+                        audio=None,
+                    ),
+                    logprobs=None,
+                )
+            ],
+            provider_specific_fields=None,
+            stream_options=None,
+        ),
+    ]
+    chunk_processor = ChunkProcessor(chunks=chunks)
+
+    tool_calls_list = chunk_processor.get_combined_tool_content(chunks)
+    assert tool_calls_list == [
+        ChatCompletionMessageToolCall(
+            id="call_m87w",
+            function=Function(
+                arguments='{"location": "San Francisco", "unit": "imperial"}',
+                name="get_current_weather",
+            ),
+            type="function",
+            index=0,
+        ),
+        ChatCompletionMessageToolCall(
+            id="call_rrns",
+            function=Function(
+                arguments='{"location": "Tokyo", "unit": "metric"}',
+                name="get_current_weather",
+            ),
+            type="function",
+            index=1,
+        ),
+        ChatCompletionMessageToolCall(
+            id="call_0k29",
+            function=Function(
+                arguments='{"location": "Paris", "unit": "metric"}',
+                name="get_current_weather",
+            ),
+            type="function",
+            index=2,
+        ),
+    ]
@@ -9,6 +9,8 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
 )
 from litellm.types.utils import ChoiceLogprobs
+from pydantic import BaseModel
+from typing import List, cast


 def test_top_logprobs():

@@ -62,3 +64,160 @@ def test_get_model_name_from_gemini_spec_model():
     model = "gemini/ft-uuid-123"
     result = VertexGeminiConfig._get_model_name_from_gemini_spec_model(model)
     assert result == "ft-uuid-123"
+
+
+def test_vertex_ai_response_schema_dict():
+    v = VertexGeminiConfig()
+    transformed_request = v.map_openai_params(
+        non_default_params={
+            "messages": [{"role": "user", "content": "Hello, world!"}],
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "math_reasoning",
+                    "schema": {
+                        "type": "object",
+                        "properties": {
+                            "steps": {
+                                "type": "array",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "thought": {"type": "string"},
+                                        "output": {"type": "string"},
+                                    },
+                                    "required": ["thought", "output"],
+                                    "additionalProperties": False,
+                                },
+                            },
+                            "final_answer": {"type": "string"},
+                        },
+                        "required": ["steps", "final_answer"],
+                        "additionalProperties": False,
+                    },
+                    "strict": False,
+                },
+            },
+        },
+        optional_params={},
+        model="gemini-2.0-flash-lite",
+        drop_params=False,
+    )
+
+    schema = transformed_request["response_schema"]
+    # should add propertyOrdering
+    assert schema["propertyOrdering"] == ["steps", "final_answer"]
+    # should add propertyOrdering (recursively, including array items)
+    assert schema["properties"]["steps"]["items"]["propertyOrdering"] == [
+        "thought",
+        "output",
+    ]
+    # should strip strict and additionalProperties
+    assert "strict" not in schema
+    assert "additionalProperties" not in schema
+    # validate the whole thing to catch regressions
+    assert transformed_request["response_schema"] == {
+        "type": "object",
+        "properties": {
+            "steps": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "thought": {"type": "string"},
+                        "output": {"type": "string"},
+                    },
+                    "required": ["thought", "output"],
+                    "propertyOrdering": ["thought", "output"],
+                },
+            },
+            "final_answer": {"type": "string"},
+        },
+        "required": ["steps", "final_answer"],
+        "propertyOrdering": ["steps", "final_answer"],
+    }
+
+
+class MathReasoning(BaseModel):
+    steps: List["Step"]
+    final_answer: str
+
+
+class Step(BaseModel):
+    thought: str
+    output: str
+
+
+def test_vertex_ai_response_schema_defs():
+    v = VertexGeminiConfig()
+
+    schema = cast(dict, v.get_json_schema_from_pydantic_object(MathReasoning))
+
+    # pydantic conversion by default adds $defs to the schema, make sure this is still the case, otherwise this test isn't really testing anything
+    assert "$defs" in schema["json_schema"]["schema"]
+
+    transformed_request = v.map_openai_params(
+        non_default_params={
+            "messages": [{"role": "user", "content": "Hello, world!"}],
+            "response_format": schema,
+        },
+        optional_params={},
+        model="gemini-2.0-flash-lite",
+        drop_params=False,
+    )
+
+    assert "$defs" not in transformed_request["response_schema"]
+    assert transformed_request["response_schema"] == {
+        "title": "MathReasoning",
+        "type": "object",
+        "properties": {
+            "steps": {
+                "title": "Steps",
+                "type": "array",
+                "items": {
+                    "title": "Step",
+                    "type": "object",
+                    "properties": {
+                        "thought": {"title": "Thought", "type": "string"},
+                        "output": {"title": "Output", "type": "string"},
+                    },
+                    "required": ["thought", "output"],
+                    "propertyOrdering": ["thought", "output"],
+                },
+            },
+            "final_answer": {"title": "Final Answer", "type": "string"},
+        },
+        "required": ["steps", "final_answer"],
+        "propertyOrdering": ["steps", "final_answer"],
+    }
+
+
+def test_vertex_ai_retain_property_ordering():
+    v = VertexGeminiConfig()
+    transformed_request = v.map_openai_params(
+        non_default_params={
+            "messages": [{"role": "user", "content": "Hello, world!"}],
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "math_reasoning",
+                    "schema": {
+                        "type": "object",
+                        "properties": {
+                            "output": {"type": "string"},
+                            "thought": {"type": "string"},
+                        },
+                        "propertyOrdering": ["thought", "output"],
+                    },
+                },
+            },
+        },
+        optional_params={},
+        model="gemini-2.0-flash-lite",
+        drop_params=False,
+    )
+
+    schema = transformed_request["response_schema"]
+    # should leave existing value alone, despite dictionary ordering
+    assert schema["propertyOrdering"] == ["thought", "output"]
@@ -1,19 +1,22 @@
 import os
 import sys
+from typing import Any, Dict
 from unittest.mock import MagicMock, call, patch
-from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH

 import pytest

+from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
+
 sys.path.insert(
     0, os.path.abspath("../../..")
 )  # Adds the parent directory to the system path

 import litellm
 from litellm.llms.vertex_ai.common_utils import (
+    convert_anyof_null_to_nullable,
     get_vertex_location_from_url,
     get_vertex_project_id_from_url,
-    convert_anyof_null_to_nullable
+    set_schema_property_ordering,
 )

@@ -44,31 +47,19 @@ async def test_get_vertex_location_from_url():
     location = get_vertex_location_from_url(url)
     assert location is None


 def test_basic_anyof_conversion():
     """Test basic conversion of anyOf with 'null'."""
     schema = {
         "type": "object",
-        "properties": {
-            "example": {
-                "anyOf": [
-                    {"type": "string"},
-                    {"type": "null"}
-                ]
-            }
-        }
+        "properties": {"example": {"anyOf": [{"type": "string"}, {"type": "null"}]}},
     }

     convert_anyof_null_to_nullable(schema)

     expected = {
         "type": "object",
-        "properties": {
-            "example": {
-                "anyOf": [
-                    {"type": "string", "nullable": True}
-                ]
-            }
-        }
+        "properties": {"example": {"anyOf": [{"type": "string", "nullable": True}]}},
     }
     assert schema == expected

@@ -85,12 +76,12 @@ def test_nested_anyof_conversion():
                     "anyOf": [
                         {"type": "array", "items": {"type": "string"}},
                         {"type": "string"},
-                        {"type": "null"}
+                        {"type": "null"},
                     ]
                 }
-            }
+            },
         }
-    }
+    },
     }

@@ -103,16 +94,21 @@ def test_nested_anyof_conversion():
                 "properties": {
                     "inner": {
                         "anyOf": [
-                            {"type": "array", "items": {"type": "string"}, "nullable": True},
-                            {"type": "string", "nullable": True}
+                            {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "nullable": True,
+                            },
+                            {"type": "string", "nullable": True},
                         ]
                     }
-                }
+                },
             }
-        }
+        },
     }
     assert schema == expected


 def test_anyof_with_excessive_nesting():
     """Test conversion with excessive nesting > max levels +1 deep."""
     # generate a schema with excessive nesting

@@ -121,21 +117,19 @@ def test_anyof_with_excessive_nesting():
     for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
         current["properties"] = {
             "nested": {
-                "anyOf": [
-                    {"type": "string"},
-                    {"type": "null"}
-                ],
-                "properties": {}
+                "anyOf": [{"type": "string"}, {"type": "null"}],
+                "properties": {},
             }
         }
         current = current["properties"]["nested"]


     # running the conversion will raise an error
-    with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
+    with pytest.raises(
+        ValueError,
+        match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting.",
+    ):
         convert_anyof_null_to_nullable(schema)


 @pytest.mark.asyncio
 async def test_get_supports_system_message():

@@ -153,93 +147,20 @@ async def test_get_supports_system_message():
         model="random-model-name", custom_llm_provider="vertex_ai"
     )
     assert result == False
-def test_basic_anyof_conversion():
-    """Test basic conversion of anyOf with 'null'."""
-    schema = {
-        "type": "object",
-        "properties": {
-            "example": {
-                "anyOf": [
-                    {"type": "string"},
-                    {"type": "null"}
-                ]
-            }
-        }
-    }
-
-    convert_anyof_null_to_nullable(schema)
-
-    expected = {
-        "type": "object",
-        "properties": {
-            "example": {
-                "anyOf": [
-                    {"type": "string", "nullable": True}
-                ]
-            }
-        }
-    }
-    assert schema == expected
-
-
-def test_nested_anyof_conversion():
-    """Test nested conversion with 'anyOf' inside properties."""
-    schema = {
-        "type": "object",
-        "properties": {
-            "outer": {
-                "type": "object",
-                "properties": {
-                    "inner": {
-                        "anyOf": [
-                            {"type": "array", "items": {"type": "string"}},
-                            {"type": "string"},
-                            {"type": "null"}
-                        ]
-                    }
-                }
-            }
-        }
-    }
-
-    convert_anyof_null_to_nullable(schema)
-
-    expected = {
-        "type": "object",
-        "properties": {
-            "outer": {
-                "type": "object",
-                "properties": {
-                    "inner": {
-                        "anyOf": [
-                            {"type": "array", "items": {"type": "string"}, "nullable": True},
-                            {"type": "string", "nullable": True}
-                        ]
-                    }
-                }
-            }
-        }
-    }
-    assert schema == expected
-
-def test_anyof_with_excessive_nesting():
-    """Test conversion with excessive nesting > max levels +1 deep."""
+
+
+def test_set_schema_property_ordering_with_excessive_nesting():
+    """Test set_schema_property_ordering with excessive nesting > max levels +1 deep."""
     # generate a schema with excessive nesting
     schema = {"type": "object", "properties": {}}
     current = schema
     for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
-        current["properties"] = {
-            "nested": {
-                "anyOf": [
-                    {"type": "string"},
-                    {"type": "null"}
-                ],
-                "properties": {}
-            }
-        }
+        current["properties"] = {"nested": {"type": "object", "properties": {}}}
         current = current["properties"]["nested"]


-    # running the conversion will raise an error
-    with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
-        convert_anyof_null_to_nullable(schema)
+    # running the function will raise an error
+    with pytest.raises(
+        ValueError,
+        match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting.",
+    ):
        set_schema_property_ordering(schema)
@@ -71,6 +71,11 @@ def test_completion_pydantic_obj_2():
                             "type": "array",
                         },
                     },
+                    "propertyOrdering": [
+                        "name",
+                        "date",
+                        "participants",
+                    ],
                     "required": ["name", "date", "participants"],
                     "title": "CalendarEvent",
                     "type": "object",

@@ -79,6 +84,7 @@ def test_completion_pydantic_obj_2():
                 "type": "array",
             }
         },
+        "propertyOrdering": ["events"],
         "required": ["events"],
         "title": "EventsList",
         "type": "object",
@@ -926,12 +926,17 @@ def execute_completion(opts: dict):
     response_gen = litellm.completion(**opts)
     for i, part in enumerate(response_gen):
         partial_streaming_chunks.append(part)
+    print("\n\n")
+    print(f"partial_streaming_chunks: {partial_streaming_chunks}")
+    print("\n\n")
     assembly = litellm.stream_chunk_builder(partial_streaming_chunks)
-    print(assembly.choices[0].message.tool_calls)
+    print(f"assembly.choices[0].message.tool_calls: {assembly.choices[0].message.tool_calls}")
     assert len(assembly.choices[0].message.tool_calls) == 3, (
         assembly.choices[0].message.tool_calls[0].function.arguments[0]
     )
     print(assembly.choices[0].message.tool_calls)
+    for tool_call in assembly.choices[0].message.tool_calls:
+        json.loads(tool_call.function.arguments)  # assert valid json - https://github.com/BerriAI/litellm/issues/10034


 def test_grok_bug(load_env):