Minor LiteLLM Fixes and Improvements (#5456)

* fix(utils.py): support 'drop_params' for embedding requests

Fixes https://github.com/BerriAI/litellm/issues/5444
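
An illustrative sketch (not part of this diff) of the behavior the fix targets: with drop_params set, provider-unsupported kwargs on an embedding call are dropped instead of raising. The model name and the extra parameter below are assumptions for illustration only.

    import litellm

    # Per-request flag; setting `litellm.drop_params = True` globally behaves the same way.
    response = litellm.embedding(
        model="vertex_ai/textembedding-gecko",  # illustrative model choice
        input=["hello world"],
        encoding_format="float",  # example of a kwarg a given provider may not accept
        drop_params=True,  # with this fix, unsupported params are dropped for embedding calls too
    )
    print(len(response.data))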

* feat(vertex_ai_non_gemini.py): support function param in messages
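
A minimal usage sketch of what this enables, mirroring the test added in this commit; the model name is taken from that test, and Vertex AI credentials are assumed to be configured.

    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "search",
                "description": "Executes searches.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "queries": {
                            "type": "array",
                            "description": "A list of queries to search for.",
                            "items": {"type": "string"},
                        },
                    },
                    "required": ["queries"],
                },
            },
        },
    ]

    messages = [
        {"role": "user", "content": "search for weather in boston (use `search`)"},
        {
            "role": "assistant",
            "content": None,
            # legacy OpenAI `function_call` field, now converted to a Gemini function_call part
            "function_call": {
                "name": "search",
                "arguments": '{"queries": ["weather in boston"]}',
            },
        },
        {
            # `function` role results are mapped to Gemini function_response parts
            "role": "function",
            "name": "search",
            "content": "The current weather in Boston is 22°F.",
        },
    ]

    response = litellm.completion(
        model="vertex_ai/gemini-1.5-pro",
        messages=messages,
        tools=tools,
        tool_choice="auto",
    )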

* test: skip test - model end of life

* fix(vertex_ai_non_gemini.py): fix gemini history parsing
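
A sketch of the conversion path the history fix touches: consecutive user/assistant turns are folded into Gemini "user"/"model" contents, and tool/function results become function_response parts. The import path below is an assumption based on the file named in this commit message.

    # assumed module path for the helper changed in this commit
    from litellm.llms.vertex_ai_non_gemini import _gemini_convert_messages_with_history

    messages = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello!"},
        {"role": "user", "content": [{"type": "text", "text": "what's 2 + 2?"}]},
    ]

    # pure message conversion; no network call is made for text-only content
    contents = _gemini_convert_messages_with_history(messages=messages)
    for content in contents:
        print(content.get("role"), content["parts"])
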
Krish Dholakia authored 2024-08-31 17:58:10 -07:00, committed by GitHub
parent 54b60a9afd
commit 6fb82aaf75
5 changed files with 248 additions and 65 deletions


@@ -26,6 +26,12 @@ from litellm.types.completion import (
)
from litellm.types.llms.anthropic import *
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.llms.openai import (
ChatCompletionAssistantMessage,
ChatCompletionFunctionMessage,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolMessage,
)
from litellm.types.utils import GenericImageParsingChunk
@@ -964,8 +970,28 @@ def infer_protocol_value(
return "unknown"
def _gemini_tool_call_invoke_helper(
function_call_params: ChatCompletionToolCallFunctionChunk,
) -> Optional[litellm.types.llms.vertex_ai.FunctionCall]:
name = function_call_params.get("name", "") or ""
arguments = function_call_params.get("arguments", "")
arguments_dict = json.loads(arguments)
function_call: Optional[litellm.types.llms.vertex_ai.FunctionCall] = None
for k, v in arguments_dict.items():
inferred_protocol_value = infer_protocol_value(value=v)
_field = litellm.types.llms.vertex_ai.Field(
key=k, value={inferred_protocol_value: v}
)
_fields = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
function_call = litellm.types.llms.vertex_ai.FunctionCall(
name=name,
args=_fields,
)
return function_call
def convert_to_gemini_tool_call_invoke(
tool_calls: list,
message: ChatCompletionAssistantMessage,
) -> List[litellm.types.llms.vertex_ai.PartType]:
"""
OpenAI tool invokes:
@@ -1036,49 +1062,55 @@ def convert_to_gemini_tool_call_invoke(
"""
try:
_parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
for tool in tool_calls:
if "function" in tool:
name = tool["function"].get("name", "")
arguments = tool["function"].get("arguments", "")
arguments_dict = json.loads(arguments)
function_call: Optional[litellm.types.llms.vertex_ai.FunctionCall] = (
None
tool_calls = message.get("tool_calls", None)
function_call = message.get("function_call", None)
if tool_calls is not None:
for tool in tool_calls:
if "function" in tool:
gemini_function_call: Optional[
litellm.types.llms.vertex_ai.FunctionCall
] = _gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
if gemini_function_call is not None:
_parts_list.append(
litellm.types.llms.vertex_ai.PartType(
function_call=gemini_function_call
)
)
else: # don't silently drop params. Make it clear to user what's happening.
raise Exception(
"function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
tool
)
)
elif function_call is not None:
gemini_function_call = _gemini_tool_call_invoke_helper(
function_call_params=function_call
)
if gemini_function_call is not None:
_parts_list.append(
litellm.types.llms.vertex_ai.PartType(
function_call=gemini_function_call
)
)
for k, v in arguments_dict.items():
inferred_protocol_value = infer_protocol_value(value=v)
_field = litellm.types.llms.vertex_ai.Field(
key=k, value={inferred_protocol_value: v}
)
_fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
fields=_field
)
function_call = litellm.types.llms.vertex_ai.FunctionCall(
name=name,
args=_fields,
)
if function_call is not None:
_parts_list.append(
litellm.types.llms.vertex_ai.PartType(
function_call=function_call
)
)
else: # don't silently drop params. Make it clear to user what's happening.
raise Exception(
"function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
tool
)
else: # don't silently drop params. Make it clear to user what's happening.
raise Exception(
"function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
tool
)
)
return _parts_list
except Exception as e:
raise Exception(
"Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
tool_calls, str(e)
message, str(e)
)
)
def convert_to_gemini_tool_call_result(
message: dict,
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
last_message_with_tool_calls: Optional[dict],
) -> litellm.types.llms.vertex_ai.PartType:
"""
@@ -1098,7 +1130,7 @@ def convert_to_gemini_tool_call_result(
}
"""
content = message.get("content", "")
name = ""
name: Optional[str] = message.get("name", "") # type: ignore
# Recover name from last message with tool calls
if last_message_with_tool_calls:
@@ -1114,7 +1146,11 @@ def convert_to_gemini_tool_call_result(
name = tool.get("function", {}).get("name", "")
if not name:
raise Exception("Missing corresponding tool call for tool response message")
raise Exception(
"Missing corresponding tool call for tool response message. Received - message={}, last_message_with_tool_calls={}".format(
message, last_message_with_tool_calls
)
)
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
@@ -1127,7 +1163,7 @@ def convert_to_gemini_tool_call_result(
_function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
_function_response = litellm.types.llms.vertex_ai.FunctionResponse(
name=name, response=_function_call_args
name=name, response=_function_call_args # type: ignore
)
_part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
@@ -1782,7 +1818,9 @@ def cohere_messages_pt_v2(
assistant_tool_calls: List[ToolCallObject] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if messages[msg_i].get("content", None) is not None and isinstance(messages[msg_i]["content"], list):
if messages[msg_i].get("content", None) is not None and isinstance(
messages[msg_i]["content"], list
):
for m in messages[msg_i]["content"]:
if m.get("type", "") == "text":
assistant_content += m["text"]


@@ -1433,7 +1433,7 @@ class VertexLLM(BaseLLM):
},
)
if stream is not None and stream is True:
if stream is True:
request_data_str = json.dumps(data)
streaming_response = CustomStreamWrapper(
completion_stream=None,


@@ -25,6 +25,7 @@ from litellm.types.files import (
is_gemini_1_5_accepted_file_type,
is_video_file_type,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -123,7 +124,9 @@ def _process_gemini_image(image_url: str) -> PartType:
raise e
def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
def _gemini_convert_messages_with_history(
messages: List[AllMessageValues],
) -> List[ContentType]:
"""
Converts given messages from OpenAI format to Gemini format
@@ -145,23 +148,26 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
):
if isinstance(messages[msg_i]["content"], list):
if messages[msg_i]["content"] is not None and isinstance(
messages[msg_i]["content"], list
):
_parts: List[PartType] = []
for element in messages[msg_i]["content"]:
for element in messages[msg_i]["content"]: # type: ignore
if isinstance(element, dict):
if element["type"] == "text" and len(element["text"]) > 0:
_part = PartType(text=element["text"])
if element["type"] == "text" and len(element["text"]) > 0: # type: ignore
_part = PartType(text=element["text"]) # type: ignore
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
image_url = element["image_url"]["url"] # type: ignore
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part) # type: ignore
user_content.extend(_parts)
elif (
isinstance(messages[msg_i]["content"], str)
and len(messages[msg_i]["content"]) > 0
messages[msg_i]["content"] is not None
and isinstance(messages[msg_i]["content"], str)
and len(messages[msg_i]["content"]) > 0 # type: ignore
):
_part = PartType(text=messages[msg_i]["content"])
_part = PartType(text=messages[msg_i]["content"]) # type: ignore
user_content.append(_part)
msg_i += 1
@@ -175,31 +181,34 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
messages[msg_i]["content"], list
):
_parts = []
for element in messages[msg_i]["content"]:
for element in messages[msg_i]["content"]: # type: ignore
if isinstance(element, dict):
if element["type"] == "text":
_part = PartType(text=element["text"])
_part = PartType(text=element["text"]) # type: ignore
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
image_url = element["image_url"]["url"] # type: ignore
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part) # type: ignore
assistant_content.extend(_parts)
elif (
messages[msg_i].get("content", None) is not None
and isinstance(messages[msg_i]["content"], str)
and messages[msg_i]["content"]
):
assistant_text = messages[msg_i]["content"] # either string or none
assistant_content.append(PartType(text=assistant_text)) # type: ignore
elif messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_gemini_tool_call_invoke(
messages[msg_i]["tool_calls"]
)
convert_to_gemini_tool_call_invoke(messages[msg_i]) # type: ignore
)
last_message_with_tool_calls = messages[msg_i]
else:
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(PartType(text=assistant_text))
elif messages[msg_i].get("function_call") is not None:
assistant_content.extend(
convert_to_gemini_tool_call_invoke(messages[msg_i]) # type: ignore
)
msg_i += 1
@@ -207,12 +216,16 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
contents.append(ContentType(role="model", parts=assistant_content))
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
if msg_i < len(messages) and (
messages[msg_i]["role"] == "tool"
or messages[msg_i]["role"] == "function"
):
_part = convert_to_gemini_tool_call_result(
messages[msg_i], last_message_with_tool_calls
messages[msg_i], last_message_with_tool_calls # type: ignore
)
contents.append(ContentType(parts=[_part])) # type: ignore
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(


@@ -906,6 +906,7 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
"tools": tools,
"tool_choice": "required",
}
print(f"Model for call - {model}")
if sync_mode:
response = litellm.completion(**data)
else:
@@ -2630,6 +2631,129 @@ async def test_partner_models_httpx_ai21():
print(f"response: {response}")
print("hidden params from response=", response._hidden_params)
assert response._hidden_params["response_cost"] > 0
def test_gemini_function_call_parameter_in_messages():
litellm.set_verbose = True
load_vertex_ai_credentials()
from litellm.llms.custom_httpx.http_handler import HTTPHandler
tools = [
{
"type": "function",
"function": {
"name": "search",
"description": "Executes searches.",
"parameters": {
"type": "object",
"properties": {
"queries": {
"type": "array",
"description": "A list of queries to search for.",
"items": {"type": "string"},
},
},
"required": ["queries"],
},
},
},
]
# Set up the messages
messages = [
{"role": "system", "content": """Use search for most queries."""},
{"role": "user", "content": """search for weather in boston (use `search`)"""},
{
"role": "assistant",
"content": None,
"function_call": {
"name": "search",
"arguments": '{"queries": ["weather in boston"]}',
},
},
{
"role": "function",
"name": "search",
"content": "The current weather in Boston is 22°F.",
},
]
client = HTTPHandler(concurrent_limit=1)
with patch.object(client, "post", new=MagicMock()) as mock_client:
try:
response_stream = completion(
model="vertex_ai/gemini-1.5-pro",
messages=messages,
tools=tools,
tool_choice="auto",
client=client,
)
except Exception as e:
print(e)
# mock_client.assert_any_call()
assert {
"contents": [
{
"role": "user",
"parts": [{"text": "search for weather in boston (use `search`)"}],
},
{
"role": "model",
"parts": [
{
"function_call": {
"name": "search",
"args": {
"fields": {
"key": "queries",
"value": {"list_value": ["weather in boston"]},
}
},
}
}
],
},
{
"parts": [
{
"function_response": {
"name": "search",
"response": {
"fields": {
"key": "content",
"value": {
"string_value": "The current weather in Boston is 22°F."
},
}
},
}
}
]
},
],
"system_instruction": {"parts": [{"text": "Use search for most queries."}]},
"tools": [
{
"function_declarations": [
{
"name": "search",
"description": "Executes searches.",
"parameters": {
"type": "object",
"properties": {
"queries": {
"type": "array",
"description": "A list of queries to search for.",
"items": {"type": "string"},
}
},
"required": ["queries"],
},
}
]
}
],
"toolConfig": {"functionCallingConfig": {"mode": "AUTO"}},
"generationConfig": {},
} == mock_client.call_args.kwargs["json"]


@@ -364,8 +364,9 @@ class ChatCompletionUserMessage(TypedDict):
class ChatCompletionAssistantMessage(TypedDict, total=False):
role: Required[Literal["assistant"]]
content: Optional[str]
name: str
tool_calls: List[ChatCompletionAssistantToolCall]
name: Optional[str]
tool_calls: Optional[List[ChatCompletionAssistantToolCall]]
function_call: Optional[ChatCompletionToolCallFunctionChunk]
class ChatCompletionToolMessage(TypedDict):
@@ -374,6 +375,12 @@ class ChatCompletionToolMessage(TypedDict):
tool_call_id: str
class ChatCompletionFunctionMessage(TypedDict):
role: Literal["function"]
content: Optional[str]
name: str
class ChatCompletionSystemMessage(TypedDict, total=False):
role: Required[Literal["system"]]
content: Required[Union[str, List]]
@@ -385,6 +392,7 @@ AllMessageValues = Union[
ChatCompletionAssistantMessage,
ChatCompletionToolMessage,
ChatCompletionSystemMessage,
ChatCompletionFunctionMessage,
]