Minor LiteLLM Fixes and Improvements (#5456)
* fix(utils.py): support 'drop_params' for embedding requests. Fixes https://github.com/BerriAI/litellm/issues/5444
* feat(vertex_ai_non_gemini.py): support function param in messages
* test: skip test - model end of life
* fix(vertex_ai_non_gemini.py): fix gemini history parsing
parent 54b60a9afd
commit 6fb82aaf75
5 changed files with 248 additions and 65 deletions
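The `drop_params` fix means provider-unsupported parameters on embedding calls are now dropped instead of raising. A minimal sketch of the intended behavior — the model name and the `dimensions` kwarg below are illustrative assumptions, not taken from this diff:

import litellm

# Hypothetical call: with drop_params=True, a kwarg the provider's embedding
# endpoint does not support is dropped instead of raising an error.
response = litellm.embedding(
    model="vertex_ai/textembedding-gecko",   # illustrative model
    input=["hello world"],
    dimensions=256,     # assumed-unsupported param for this provider/model
    drop_params=True,   # after the fix, honored for embeddings as well
)
print(response)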
File 1 of 5 — prompt-template conversion helpers (OpenAI → Gemini)

@@ -26,6 +26,12 @@ from litellm.types.completion import (
 )
 from litellm.types.llms.anthropic import *
 from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
+from litellm.types.llms.openai import (
+    ChatCompletionAssistantMessage,
+    ChatCompletionFunctionMessage,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionToolMessage,
+)
 from litellm.types.utils import GenericImageParsingChunk
@@ -964,8 +970,28 @@ def infer_protocol_value(
     return "unknown"


+def _gemini_tool_call_invoke_helper(
+    function_call_params: ChatCompletionToolCallFunctionChunk,
+) -> Optional[litellm.types.llms.vertex_ai.FunctionCall]:
+    name = function_call_params.get("name", "") or ""
+    arguments = function_call_params.get("arguments", "")
+    arguments_dict = json.loads(arguments)
+    function_call: Optional[litellm.types.llms.vertex_ai.FunctionCall] = None
+    for k, v in arguments_dict.items():
+        inferred_protocol_value = infer_protocol_value(value=v)
+        _field = litellm.types.llms.vertex_ai.Field(
+            key=k, value={inferred_protocol_value: v}
+        )
+        _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
+        function_call = litellm.types.llms.vertex_ai.FunctionCall(
+            name=name,
+            args=_fields,
+        )
+    return function_call
+
+
 def convert_to_gemini_tool_call_invoke(
-    tool_calls: list,
+    message: ChatCompletionAssistantMessage,
 ) -> List[litellm.types.llms.vertex_ai.PartType]:
     """
     OpenAI tool invokes:
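For reference, a small sketch of the new helper's contract. The import path is an assumption (the helper lives in the prompt-template factory module shown above); the expected output shape follows from the Field/FunctionCallArgs construction and matches the test added at the bottom of this commit:

from litellm.llms.prompt_templates.factory import (  # path assumed
    _gemini_tool_call_invoke_helper,
)

gemini_call = _gemini_tool_call_invoke_helper(
    function_call_params={
        "name": "search",
        "arguments": '{"queries": ["weather in boston"]}',
    }
)
# Expected shape (these types are TypedDicts, so it reads as a dict):
# {"name": "search",
#  "args": {"fields": {"key": "queries",
#                      "value": {"list_value": ["weather in boston"]}}}}
print(gemini_call)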
@@ -1036,49 +1062,55 @@ def convert_to_gemini_tool_call_invoke(
     """
     try:
         _parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
-        for tool in tool_calls:
-            if "function" in tool:
-                name = tool["function"].get("name", "")
-                arguments = tool["function"].get("arguments", "")
-                arguments_dict = json.loads(arguments)
-                function_call: Optional[litellm.types.llms.vertex_ai.FunctionCall] = (
-                    None
-                )
-                for k, v in arguments_dict.items():
-                    inferred_protocol_value = infer_protocol_value(value=v)
-                    _field = litellm.types.llms.vertex_ai.Field(
-                        key=k, value={inferred_protocol_value: v}
-                    )
-                    _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
-                        fields=_field
-                    )
-                    function_call = litellm.types.llms.vertex_ai.FunctionCall(
-                        name=name,
-                        args=_fields,
-                    )
-                if function_call is not None:
-                    _parts_list.append(
-                        litellm.types.llms.vertex_ai.PartType(
-                            function_call=function_call
-                        )
-                    )
-                else:  # don't silently drop params. Make it clear to user what's happening.
-                    raise Exception(
-                        "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
-                            tool
-                        )
-                    )
-            else:  # don't silently drop params. Make it clear to user what's happening.
-                raise Exception(
-                    "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
-                        tool
-                    )
-                )
+        tool_calls = message.get("tool_calls", None)
+        function_call = message.get("function_call", None)
+        if tool_calls is not None:
+            for tool in tool_calls:
+                if "function" in tool:
+                    gemini_function_call: Optional[
+                        litellm.types.llms.vertex_ai.FunctionCall
+                    ] = _gemini_tool_call_invoke_helper(
+                        function_call_params=tool["function"]
+                    )
+                    if gemini_function_call is not None:
+                        _parts_list.append(
+                            litellm.types.llms.vertex_ai.PartType(
+                                function_call=gemini_function_call
+                            )
+                        )
+                else:  # don't silently drop params. Make it clear to user what's happening.
+                    raise Exception(
+                        "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
+                            tool
+                        )
+                    )
+        elif function_call is not None:
+            gemini_function_call = _gemini_tool_call_invoke_helper(
+                function_call_params=function_call
+            )
+            if gemini_function_call is not None:
+                _parts_list.append(
+                    litellm.types.llms.vertex_ai.PartType(
+                        function_call=gemini_function_call
+                    )
+                )
         return _parts_list
     except Exception as e:
         raise Exception(
             "Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
-                tool_calls, str(e)
+                message, str(e)
             )
         )


 def convert_to_gemini_tool_call_result(
-    message: dict,
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
     last_message_with_tool_calls: Optional[dict],
 ) -> litellm.types.llms.vertex_ai.PartType:
     """
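With the signature change, callers pass the whole assistant message, so the same converter serves both `tool_calls` and the legacy `function_call` field. A usage sketch (import path assumed, as above):

from litellm.llms.prompt_templates.factory import (  # path assumed
    convert_to_gemini_tool_call_invoke,
)

assistant_message = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "search",
        "arguments": '{"queries": ["weather in boston"]}',
    },
}
parts = convert_to_gemini_tool_call_invoke(assistant_message)
# One PartType per call; PartType is a TypedDict, so it reads as a dict:
print(parts[0]["function_call"]["name"])  # -> "search"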
@@ -1098,7 +1130,7 @@ def convert_to_gemini_tool_call_result(
     }
     """
     content = message.get("content", "")
-    name = ""
+    name: Optional[str] = message.get("name", "")  # type: ignore

     # Recover name from last message with tool calls
     if last_message_with_tool_calls:
@@ -1114,7 +1146,11 @@ def convert_to_gemini_tool_call_result(
                 name = tool.get("function", {}).get("name", "")

     if not name:
-        raise Exception("Missing corresponding tool call for tool response message")
+        raise Exception(
+            "Missing corresponding tool call for tool response message. Received - message={}, last_message_with_tool_calls={}".format(
+                message, last_message_with_tool_calls
+            )
+        )

     # We can't determine from openai message format whether it's a successful or
     # error call result so default to the successful result template
@@ -1127,7 +1163,7 @@ def convert_to_gemini_tool_call_result(
     _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)

     _function_response = litellm.types.llms.vertex_ai.FunctionResponse(
-        name=name, response=_function_call_args
+        name=name, response=_function_call_args  # type: ignore
     )

     _part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
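A sketch of the result-side conversion (import path assumed): a `function`-role message carries its own `name`, so no lookup in `last_message_with_tool_calls` is needed. The expected output shape matches the test added below:

from litellm.llms.prompt_templates.factory import (  # path assumed
    convert_to_gemini_tool_call_result,
)

function_message = {
    "role": "function",
    "name": "search",
    "content": "The current weather in Boston is 22°F.",
}
part = convert_to_gemini_tool_call_result(function_message, None)
# -> {"function_response": {"name": "search",
#        "response": {"fields": {"key": "content",
#            "value": {"string_value": "The current weather in Boston is 22°F."}}}}}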
@@ -1782,7 +1818,9 @@ def cohere_messages_pt_v2(
     assistant_tool_calls: List[ToolCallObject] = []
     ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
     while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-        if messages[msg_i].get("content", None) is not None and isinstance(messages[msg_i]["content"], list):
+        if messages[msg_i].get("content", None) is not None and isinstance(
+            messages[msg_i]["content"], list
+        ):
             for m in messages[msg_i]["content"]:
                 if m.get("type", "") == "text":
                     assistant_content += m["text"]
File 2 of 5 — Vertex HTTPX handler (`VertexLLM`)

@@ -1433,7 +1433,7 @@ class VertexLLM(BaseLLM):
             },
         )

-        if stream is not None and stream is True:
+        if stream is True:
             request_data_str = json.dumps(data)
             streaming_response = CustomStreamWrapper(
                 completion_stream=None,
File 3 of 5 — vertex_ai_non_gemini.py (Gemini message-history conversion)

@@ -25,6 +25,7 @@ from litellm.types.files import (
     is_gemini_1_5_accepted_file_type,
     is_video_file_type,
 )
+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.llms.vertex_ai import *
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -123,7 +124,9 @@ def _process_gemini_image(image_url: str) -> PartType:
         raise e


-def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
+def _gemini_convert_messages_with_history(
+    messages: List[AllMessageValues],
+) -> List[ContentType]:
     """
     Converts given messages from OpenAI format to Gemini format
@@ -145,23 +148,26 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
         while (
             msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
         ):
-            if isinstance(messages[msg_i]["content"], list):
+            if messages[msg_i]["content"] is not None and isinstance(
+                messages[msg_i]["content"], list
+            ):
                 _parts: List[PartType] = []
-                for element in messages[msg_i]["content"]:
+                for element in messages[msg_i]["content"]:  # type: ignore
                     if isinstance(element, dict):
-                        if element["type"] == "text" and len(element["text"]) > 0:
-                            _part = PartType(text=element["text"])
+                        if element["type"] == "text" and len(element["text"]) > 0:  # type: ignore
+                            _part = PartType(text=element["text"])  # type: ignore
                             _parts.append(_part)
                         elif element["type"] == "image_url":
-                            image_url = element["image_url"]["url"]
+                            image_url = element["image_url"]["url"]  # type: ignore
                             _part = _process_gemini_image(image_url=image_url)
                             _parts.append(_part)  # type: ignore
                 user_content.extend(_parts)
             elif (
-                isinstance(messages[msg_i]["content"], str)
-                and len(messages[msg_i]["content"]) > 0
+                messages[msg_i]["content"] is not None
+                and isinstance(messages[msg_i]["content"], str)
+                and len(messages[msg_i]["content"]) > 0  # type: ignore
             ):
-                _part = PartType(text=messages[msg_i]["content"])
+                _part = PartType(text=messages[msg_i]["content"])  # type: ignore
                 user_content.append(_part)

             msg_i += 1
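For context, the list-form user content this loop consumes is the standard OpenAI multimodal shape. An illustrative message (the URL is a placeholder):

user_message = {
    "role": "user",
    "content": [
        # text elements become PartType(text=...)
        {"type": "text", "text": "What is in this image?"},
        # image_url elements are routed through _process_gemini_image
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}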
@@ -175,31 +181,34 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
                 messages[msg_i]["content"], list
             ):
                 _parts = []
-                for element in messages[msg_i]["content"]:
+                for element in messages[msg_i]["content"]:  # type: ignore
                     if isinstance(element, dict):
                         if element["type"] == "text":
-                            _part = PartType(text=element["text"])
+                            _part = PartType(text=element["text"])  # type: ignore
                             _parts.append(_part)
                         elif element["type"] == "image_url":
-                            image_url = element["image_url"]["url"]
+                            image_url = element["image_url"]["url"]  # type: ignore
                             _part = _process_gemini_image(image_url=image_url)
                             _parts.append(_part)  # type: ignore
                 assistant_content.extend(_parts)
-            elif (
-                messages[msg_i].get("content", None) is not None
-                and isinstance(messages[msg_i]["content"], str)
-                and messages[msg_i]["content"]
-            ):
-                assistant_text = messages[msg_i]["content"]  # either string or none
-                assistant_content.append(PartType(text=assistant_text))  # type: ignore
             elif messages[msg_i].get(
                 "tool_calls", []
             ):  # support assistant tool invoke conversion
                 assistant_content.extend(
-                    convert_to_gemini_tool_call_invoke(
-                        messages[msg_i]["tool_calls"]
-                    )
+                    convert_to_gemini_tool_call_invoke(messages[msg_i])  # type: ignore
                 )
                 last_message_with_tool_calls = messages[msg_i]
+            elif messages[msg_i].get("function_call") is not None:
+                assistant_content.extend(
+                    convert_to_gemini_tool_call_invoke(messages[msg_i])  # type: ignore
+                )
+            else:
+                assistant_text = (
+                    messages[msg_i].get("content") or ""
+                )  # either string or none
+                if assistant_text:
+                    assistant_content.append(PartType(text=assistant_text))

             msg_i += 1
@@ -207,12 +216,16 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
         contents.append(ContentType(role="model", parts=assistant_content))

     ## APPEND TOOL CALL MESSAGES ##
-    if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+    if msg_i < len(messages) and (
+        messages[msg_i]["role"] == "tool"
+        or messages[msg_i]["role"] == "function"
+    ):
         _part = convert_to_gemini_tool_call_result(
-            messages[msg_i], last_message_with_tool_calls
+            messages[msg_i], last_message_with_tool_calls  # type: ignore
         )
         contents.append(ContentType(parts=[_part]))  # type: ignore
         msg_i += 1

     if msg_i == init_msg_i:  # prevent infinite loops
         raise Exception(
             "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
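Putting the history conversion together — a sketch mirroring the test added below. The module path is taken from the file named in the commit message, so treat it as an assumption:

from litellm.llms.vertex_ai_non_gemini import (  # path assumed
    _gemini_convert_messages_with_history,
)

messages = [
    {"role": "user", "content": "search for weather in boston (use `search`)"},
    {
        "role": "assistant",
        "content": None,
        "function_call": {
            "name": "search",
            "arguments": '{"queries": ["weather in boston"]}',
        },
    },
    {
        "role": "function",
        "name": "search",
        "content": "The current weather in Boston is 22°F.",
    },
]

contents = _gemini_convert_messages_with_history(messages=messages)
# -> [user content, model content with a function_call part,
#     content with a function_response part]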
File 4 of 5 — Vertex AI tests

@@ -906,6 +906,7 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
         "tools": tools,
         "tool_choice": "required",
     }
+    print(f"Model for call - {model}")
     if sync_mode:
         response = litellm.completion(**data)
     else:
@@ -2630,6 +2631,129 @@ async def test_partner_models_httpx_ai21():

     print(f"response: {response}")

     print("hidden params from response=", response._hidden_params)

     assert response._hidden_params["response_cost"] > 0
+
+
+def test_gemini_function_call_parameter_in_messages():
+    litellm.set_verbose = True
+    load_vertex_ai_credentials()
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "search",
+                "description": "Executes searches.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "queries": {
+                            "type": "array",
+                            "description": "A list of queries to search for.",
+                            "items": {"type": "string"},
+                        },
+                    },
+                    "required": ["queries"],
+                },
+            },
+        },
+    ]
+
+    # Set up the messages
+    messages = [
+        {"role": "system", "content": """Use search for most queries."""},
+        {"role": "user", "content": """search for weather in boston (use `search`)"""},
+        {
+            "role": "assistant",
+            "content": None,
+            "function_call": {
+                "name": "search",
+                "arguments": '{"queries": ["weather in boston"]}',
+            },
+        },
+        {
+            "role": "function",
+            "name": "search",
+            "content": "The current weather in Boston is 22°F.",
+        },
+    ]
+
+    client = HTTPHandler(concurrent_limit=1)
+
+    with patch.object(client, "post", new=MagicMock()) as mock_client:
+        try:
+            response_stream = completion(
+                model="vertex_ai/gemini-1.5-pro",
+                messages=messages,
+                tools=tools,
+                tool_choice="auto",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        # mock_client.assert_any_call()
+        assert {
+            "contents": [
+                {
+                    "role": "user",
+                    "parts": [{"text": "search for weather in boston (use `search`)"}],
+                },
+                {
+                    "role": "model",
+                    "parts": [
+                        {
+                            "function_call": {
+                                "name": "search",
+                                "args": {
+                                    "fields": {
+                                        "key": "queries",
+                                        "value": {"list_value": ["weather in boston"]},
+                                    }
+                                },
+                            }
+                        }
+                    ],
+                },
+                {
+                    "parts": [
+                        {
+                            "function_response": {
+                                "name": "search",
+                                "response": {
+                                    "fields": {
+                                        "key": "content",
+                                        "value": {
+                                            "string_value": "The current weather in Boston is 22°F."
+                                        },
+                                    }
+                                },
+                            }
+                        }
+                    ]
+                },
+            ],
+            "system_instruction": {"parts": [{"text": "Use search for most queries."}]},
+            "tools": [
+                {
+                    "function_declarations": [
+                        {
+                            "name": "search",
+                            "description": "Executes searches.",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "queries": {
+                                        "type": "array",
+                                        "description": "A list of queries to search for.",
+                                        "items": {"type": "string"},
+                                    }
+                                },
+                                "required": ["queries"],
+                            },
+                        }
+                    ]
+                }
+            ],
+            "toolConfig": {"functionCallingConfig": {"mode": "AUTO"}},
+            "generationConfig": {},
+        } == mock_client.call_args.kwargs["json"]
File 5 of 5 — litellm/types/llms/openai.py message TypedDicts

@@ -364,8 +364,9 @@ class ChatCompletionUserMessage(TypedDict):
 class ChatCompletionAssistantMessage(TypedDict, total=False):
     role: Required[Literal["assistant"]]
     content: Optional[str]
-    name: str
-    tool_calls: List[ChatCompletionAssistantToolCall]
+    name: Optional[str]
+    tool_calls: Optional[List[ChatCompletionAssistantToolCall]]
+    function_call: Optional[ChatCompletionToolCallFunctionChunk]


 class ChatCompletionToolMessage(TypedDict):
@@ -374,6 +375,12 @@ class ChatCompletionToolMessage(TypedDict):
     tool_call_id: str


+class ChatCompletionFunctionMessage(TypedDict):
+    role: Literal["function"]
+    content: Optional[str]
+    name: str
+
+
 class ChatCompletionSystemMessage(TypedDict, total=False):
     role: Required[Literal["system"]]
     content: Required[Union[str, List]]
@@ -385,6 +392,7 @@ AllMessageValues = Union[
     ChatCompletionAssistantMessage,
     ChatCompletionToolMessage,
     ChatCompletionSystemMessage,
+    ChatCompletionFunctionMessage,
 ]
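The loosened assistant type and the new function-message type together cover the legacy OpenAI function-calling shape. A small type-checking sketch using the types as declared in this diff:

from litellm.types.llms.openai import (
    ChatCompletionAssistantMessage,
    ChatCompletionFunctionMessage,
)

# Assistant turn that invokes a function via the legacy `function_call` field;
# total=False means name/tool_calls may be omitted.
assistant: ChatCompletionAssistantMessage = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "search",
        "arguments": '{"queries": ["weather in boston"]}',
    },
}

# The matching function-result turn, now representable in AllMessageValues.
function_result: ChatCompletionFunctionMessage = {
    "role": "function",
    "content": "The current weather in Boston is 22°F.",
    "name": "search",
}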