Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)
* fix #9783: Retain schema field ordering for google gemini and vertex (#9828)
* test: update test
* refactor(groq.py): initial commit migrating groq to base_llm_http_handler
* fix(streaming_chunk_builder_utils.py): fix how tool content is combined
  Fixes https://github.com/BerriAI/litellm/issues/10034
* fix(vertex_ai/common_utils.py): prevent infinite loop in helper function
* fix(groq/chat/transformation.py): handle groq streaming errors correctly
* fix(groq/chat/transformation.py): handle max_retries
---------
Co-authored-by: Adrian Lyjak <adrian@chatmeter.com>
parent 1b9b745cae
commit fdfa1108a6
12 changed files with 493 additions and 201 deletions
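For context on the #9783 fix: Gemini and Vertex return structured-output fields alphabetically unless the request carries a propertyOrdering, which this commit now derives from the declared schema order. A minimal sketch of the kind of call affected (model id, schema, and API key setup are illustrative, not taken from this commit):

import litellm  # assumes GEMINI_API_KEY is set in the environment

resp = litellm.completion(
    model="gemini/gemini-1.5-flash",  # illustrative model id
    messages=[{"role": "user", "content": "Describe one book as JSON."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "book",
            "schema": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "author": {"type": "string"},
                    "year": {"type": "integer"},
                },
            },
        },
    },
)
print(resp.choices[0].message.content)  # fields should now come back in declared order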
@@ -106,74 +106,64 @@ class ChunkProcessor:
     def get_combined_tool_content(
         self, tool_call_chunks: List[Dict[str, Any]]
     ) -> List[ChatCompletionMessageToolCall]:
-        argument_list: List[str] = []
-        delta = tool_call_chunks[0]["choices"][0]["delta"]
-        id = None
-        name = None
-        type = None
         tool_calls_list: List[ChatCompletionMessageToolCall] = []
-        prev_index = None
-        prev_name = None
-        prev_id = None
-        curr_id = None
-        curr_index = 0
+        tool_call_map: Dict[
+            int, Dict[str, Any]
+        ] = {}  # Map to store tool calls by index

         for chunk in tool_call_chunks:
             choices = chunk["choices"]
             for choice in choices:
                 delta = choice.get("delta", {})
-                tool_calls = delta.get("tool_calls", "")
-                # Check if a tool call is present
-                if tool_calls and tool_calls[0].function is not None:
-                    if tool_calls[0].id:
-                        id = tool_calls[0].id
-                        curr_id = id
-                        if prev_id is None:
-                            prev_id = curr_id
-                    if tool_calls[0].index:
-                        curr_index = tool_calls[0].index
-                    if tool_calls[0].function.arguments:
-                        # Now, tool_calls is expected to be a dictionary
-                        arguments = tool_calls[0].function.arguments
-                        argument_list.append(arguments)
-                    if tool_calls[0].function.name:
-                        name = tool_calls[0].function.name
-                    if tool_calls[0].type:
-                        type = tool_calls[0].type
-            if prev_index is None:
-                prev_index = curr_index
-            if prev_name is None:
-                prev_name = name
-            if curr_index != prev_index:  # new tool call
-                combined_arguments = "".join(argument_list)
+                tool_calls = delta.get("tool_calls", [])
+
+                for tool_call in tool_calls:
+                    if not tool_call or not hasattr(tool_call, "function"):
+                        continue
+
+                    index = getattr(tool_call, "index", 0)
+                    if index not in tool_call_map:
+                        tool_call_map[index] = {
+                            "id": None,
+                            "name": None,
+                            "type": None,
+                            "arguments": [],
+                        }
+
+                    if hasattr(tool_call, "id") and tool_call.id:
+                        tool_call_map[index]["id"] = tool_call.id
+                    if hasattr(tool_call, "type") and tool_call.type:
+                        tool_call_map[index]["type"] = tool_call.type
+                    if hasattr(tool_call, "function"):
+                        if (
+                            hasattr(tool_call.function, "name")
+                            and tool_call.function.name
+                        ):
+                            tool_call_map[index]["name"] = tool_call.function.name
+                        if (
+                            hasattr(tool_call.function, "arguments")
+                            and tool_call.function.arguments
+                        ):
+                            tool_call_map[index]["arguments"].append(
+                                tool_call.function.arguments
+                            )
+
+        # Convert the map to a list of tool calls
+        for index in sorted(tool_call_map.keys()):
+            tool_call_data = tool_call_map[index]
+            if tool_call_data["id"] and tool_call_data["name"]:
+                combined_arguments = "".join(tool_call_data["arguments"]) or "{}"
                 tool_calls_list.append(
                     ChatCompletionMessageToolCall(
-                        id=prev_id,
+                        id=tool_call_data["id"],
                         function=Function(
                             arguments=combined_arguments,
-                            name=prev_name,
+                            name=tool_call_data["name"],
                         ),
-                        type=type,
+                        type=tool_call_data["type"] or "function",
+                        index=index,
                     )
                 )
-                argument_list = []  # reset
-                prev_index = curr_index
-                prev_id = curr_id
-                prev_name = name
-
-        combined_arguments = (
-            "".join(argument_list) or "{}"
-        )  # base case, return empty dict
-
-        tool_calls_list.append(
-            ChatCompletionMessageToolCall(
-                id=id,
-                type="function",
-                function=Function(
-                    arguments=combined_arguments,
-                    name=name,
-                ),
-            )
-        )

         return tool_calls_list
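The hunk above replaces the prev/curr bookkeeping with a map keyed by each tool call's index, which is what fixes dropped or merged parallel tool calls (issue 10034). A self-contained sketch of that accumulation strategy on plain dicts (simplified stand-ins for litellm's delta objects, not the library's actual types):

from typing import Any, Dict, List

def combine_tool_call_deltas(deltas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Accumulate fragments per tool-call index, mirroring tool_call_map above.
    by_index: Dict[int, Dict[str, Any]] = {}
    for d in deltas:
        idx = d.get("index", 0)
        slot = by_index.setdefault(idx, {"id": None, "name": None, "arguments": []})
        if d.get("id"):
            slot["id"] = d["id"]
        fn = d.get("function", {})
        if fn.get("name"):
            slot["name"] = fn["name"]
        if fn.get("arguments"):
            slot["arguments"].append(fn["arguments"])
    # Emit one combined call per index, in index order, with arguments joined.
    return [
        {
            "id": slot["id"],
            "type": "function",
            "function": {"name": slot["name"], "arguments": "".join(slot["arguments"]) or "{}"},
        }
        for idx, slot in sorted(by_index.items())
        if slot["id"] and slot["name"]
    ]

# Two parallel tool calls whose argument JSON arrives in fragments:
deltas = [
    {"index": 0, "id": "call_a", "function": {"name": "get_weather", "arguments": '{"city": '}},
    {"index": 1, "id": "call_b", "function": {"name": "get_time", "arguments": '{"tz": "UTC"}'}},
    {"index": 0, "function": {"arguments": '"Paris"}'}},
]
print(combine_tool_call_deltas(deltas))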
@@ -230,6 +230,7 @@ class BaseLLMHTTPHandler:
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
+        extra_body: Optional[dict] = optional_params.pop("extra_body", None)

         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
@@ -267,6 +268,9 @@ class BaseLLMHTTPHandler:
             headers=headers,
         )

+        if extra_body is not None:
+            data = {**data, **extra_body}
+
         headers = provider_config.sign_request(
             headers=headers,
             optional_params=optional_params,
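The two hunks above pop extra_body out of optional_params and merge it into the request payload before the request is signed. A sketch of the merge semantics with plain dicts (payload shape is illustrative, not the handler's real request body):

optional_params = {"temperature": 0.2, "extra_body": {"safe_mode": True}}

extra_body = optional_params.pop("extra_body", None)
data = {"model": "example-model", "temperature": optional_params["temperature"]}

if extra_body is not None:
    # later keys win, so extra_body can add to or override the default payload
    data = {**data, **extra_body}

print(data)  # {'model': 'example-model', 'temperature': 0.2, 'safe_mode': True}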
@@ -57,6 +57,14 @@ class GroqChatConfig(OpenAIGPTConfig):
     def get_config(cls):
         return super().get_config()

+    def get_supported_openai_params(self, model: str) -> list:
+        base_params = super().get_supported_openai_params(model)
+        try:
+            base_params.remove("max_retries")
+        except ValueError:
+            pass
+        return base_params
+
     def _transform_messages(self, messages: List[AllMessageValues], model: str) -> List:
         for idx, message in enumerate(messages):
             """
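A usage sketch of the new override; the import path is assumed from the commit message's groq/chat/transformation.py and may differ across litellm versions:

from litellm.llms.groq.chat.transformation import GroqChatConfig

params = GroqChatConfig().get_supported_openai_params(model="llama-3.1-8b-instant")
assert "max_retries" not in params  # max_retries is no longer forwarded to Groq as a provider param
print(sorted(params))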
@@ -1,5 +1,5 @@
-from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
+import re
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints

 import httpx

@@ -165,9 +165,18 @@ def _check_text_in_content(parts: List[PartType]) -> bool:
     return has_text_param


-def _build_vertex_schema(parameters: dict):
+def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False):
     """
     This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
+
+    Updates the input parameters, removing extraneous fields, adjusting types, unwinding $defs, and adding propertyOrdering if specified, returning the updated parameters.
+
+    Parameters:
+        parameters: dict - the json schema to build from
+        add_property_ordering: bool - whether to add propertyOrdering to the schema. This is only applicable to schemas for structured outputs. See
+          set_schema_property_ordering for more details.
+    Returns:
+        parameters: dict - the input parameters, modified in place
     """
     # Get valid fields from Schema TypedDict
     valid_schema_fields = set(get_type_hints(Schema).keys())
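A sketch of calling the updated helper directly; the module path is assumed from the commit message's vertex_ai/common_utils.py, and the schema is illustrative:

from litellm.llms.vertex_ai.common_utils import _build_vertex_schema

schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "author": {"type": "string"},
        "year": {"type": "integer"},
    },
}
out = _build_vertex_schema(parameters=schema, add_property_ordering=True)
# Expected to carry the declared key order rather than alphabetical order:
print(out.get("propertyOrdering"))  # ['title', 'author', 'year']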
@@ -186,8 +195,40 @@ def _build_vertex_schema(parameters: dict):
     add_object_type(parameters)
     # Postprocessing
     # Filter out fields that don't exist in Schema
-    filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
-    return filtered_parameters
+    parameters = filter_schema_fields(parameters, valid_schema_fields)
+
+    if add_property_ordering:
+        set_schema_property_ordering(parameters)
+    return parameters
+
+
+def set_schema_property_ordering(
+    schema: Dict[str, Any], depth: int = 0
+) -> Dict[str, Any]:
+    """
+    vertex ai and generativeai apis order output of fields alphabetically, unless you specify the order.
+    python dicts retain order, so we just use that. Note that this field only applies to structured outputs, and not tools.
+    Function tools are not afflicted by the same alphabetical ordering issue, (the order of keys returned seems to be arbitrary, up to the model)
+    https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.cachedContents#Schema.FIELDS.property_ordering
+
+    Args:
+        schema: The schema dictionary to process
+        depth: Current recursion depth to prevent infinite loops
+    """
+    if depth > DEFAULT_MAX_RECURSE_DEPTH:
+        raise ValueError(
+            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
+        )
+
+    if "properties" in schema and isinstance(schema["properties"], dict):
+        # retain propertyOrdering as an escape hatch if user already specifies it
+        if "propertyOrdering" not in schema:
+            schema["propertyOrdering"] = [k for k, v in schema["properties"].items()]
+        for k, v in schema["properties"].items():
+            set_schema_property_ordering(v, depth + 1)
+    if "items" in schema:
+        set_schema_property_ordering(schema["items"], depth + 1)
+    return schema


 def filter_schema_fields(
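A self-contained sketch of the ordering rule set_schema_property_ordering implements: declared key order becomes propertyOrdering, nested objects and array items are handled recursively, and a caller-supplied propertyOrdering is left untouched. This is a simplified re-implementation for illustration, not litellm's code:

from typing import Any, Dict

MAX_DEPTH = 10  # stand-in for DEFAULT_MAX_RECURSE_DEPTH

def add_property_ordering(schema: Dict[str, Any], depth: int = 0) -> Dict[str, Any]:
    if depth > MAX_DEPTH:
        raise ValueError("schema nested too deeply")
    props = schema.get("properties")
    if isinstance(props, dict):
        # escape hatch: respect an ordering the caller already set
        schema.setdefault("propertyOrdering", list(props.keys()))
        for child in props.values():
            add_property_ordering(child, depth + 1)
    if "items" in schema:
        add_property_ordering(schema["items"], depth + 1)
    return schema

nested = {
    "type": "object",
    "properties": {
        "books": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {"title": {"type": "string"}, "year": {"type": "integer"}},
            },
        },
        "count": {"type": "integer"},
    },
}
print(add_property_ordering(nested)["propertyOrdering"])            # ['books', 'count']
print(nested["properties"]["books"]["items"]["propertyOrdering"])   # ['title', 'year']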
@@ -207,7 +207,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             "extra_headers",
             "seed",
             "logprobs",
-            "top_logprobs",  # Added this to list of supported openAI params
+            "top_logprobs",
             "modalities",
         ]

@@ -313,9 +313,10 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         if isinstance(old_schema, list):
             for item in old_schema:
                 if isinstance(item, dict):
-                    item = _build_vertex_schema(parameters=item)
+                    item = _build_vertex_schema(parameters=item, add_property_ordering=True)
+
         elif isinstance(old_schema, dict):
-            old_schema = _build_vertex_schema(parameters=old_schema)
+            old_schema = _build_vertex_schema(parameters=old_schema, add_property_ordering=True)
         return old_schema

     def apply_response_schema_transformation(self, value: dict, optional_params: dict):
@@ -1622,24 +1622,22 @@ def completion(  # type: ignore # noqa: PLR0915
             ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
                 optional_params[k] = v

-        response = groq_chat_completions.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
-            headers=headers,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            api_key=api_key,
-            api_base=api_base,
             acompletion=acompletion,
-            logging_obj=logging,
+            api_base=api_base,
+            model_response=model_response,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            timeout=timeout,  # type: ignore
-            custom_prompt_dict=custom_prompt_dict,
-            client=client,  # pass AsyncOpenAI, OpenAI client
             custom_llm_provider=custom_llm_provider,
+            timeout=timeout,
+            headers=headers,
             encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+            client=client,
         )
     elif custom_llm_provider == "aiohttp_openai":
         # NEW aiohttp provider for 10-100x higher RPS
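With this change the groq branch routes through base_llm_http_handler instead of the dedicated groq_chat_completions client; the caller-facing API is unchanged. A usage sketch (model id illustrative, requires GROQ_API_KEY in the environment):

import litellm

resp = litellm.completion(
    model="groq/llama-3.1-8b-instant",  # illustrative Groq model id
    messages=[{"role": "user", "content": "Say hi in one word."}],
    # per the GroqChatConfig change above, max_retries stays client-side
    # instead of being forwarded to Groq as a request parameter
    max_retries=2,
)
print(resp.choices[0].message.content)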
@@ -2658,9 +2656,9 @@ def completion(  # type: ignore # noqa: PLR0915
                 "aws_region_name" not in optional_params
                 or optional_params["aws_region_name"] is None
             ):
-                optional_params["aws_region_name"] = (
-                    aws_bedrock_client.meta.region_name
-                )
+                optional_params[
+                    "aws_region_name"
+                ] = aws_bedrock_client.meta.region_name

         bedrock_route = BedrockModelInfo.get_bedrock_route(model)
         if bedrock_route == "converse":
@@ -4367,9 +4365,9 @@ def adapter_completion(
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

     response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
-        None
-    )
+    translated_response: Optional[
+        Union[BaseModel, AdapterCompletionStreamWrapper]
+    ] = None
     if isinstance(response, ModelResponse):
         translated_response = translation_obj.translate_completion_output_params(
             response=response
@@ -5789,9 +5787,9 @@ def stream_chunk_builder(  # noqa: PLR0915
         ]

         if len(content_chunks) > 0:
-            response["choices"][0]["message"]["content"] = (
-                processor.get_combined_content(content_chunks)
-            )
+            response["choices"][0]["message"][
+                "content"
+            ] = processor.get_combined_content(content_chunks)

         reasoning_chunks = [
             chunk
@@ -5802,9 +5800,9 @@ def stream_chunk_builder(  # noqa: PLR0915
         ]

         if len(reasoning_chunks) > 0:
-            response["choices"][0]["message"]["reasoning_content"] = (
-                processor.get_combined_reasoning_content(reasoning_chunks)
-            )
+            response["choices"][0]["message"][
+                "reasoning_content"
+            ] = processor.get_combined_reasoning_content(reasoning_chunks)

         audio_chunks = [
             chunk
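The two hunks above only reflow assignments inside stream_chunk_builder; the behavior change for tool content lives in get_combined_tool_content further up. A usage sketch of the rebuild path those chunks feed, assuming litellm's public stream_chunk_builder helper (model id and tool are illustrative, requires GROQ_API_KEY):

import litellm

messages = [{"role": "user", "content": "What's the weather in Paris?"}]
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # illustrative tool, not from the commit
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]

chunks = []
for chunk in litellm.completion(
    model="groq/llama-3.1-8b-instant",  # illustrative model id
    messages=messages,
    tools=tools,
    stream=True,
):
    chunks.append(chunk)

# Rebuild one ModelResponse from the stream; tool-call fragments are merged per index.
rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
print(rebuilt.choices[0].message.tool_calls)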