From 6fb82aaf75bfb12f4a8501261b8578bce93d9b6e Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Sat, 31 Aug 2024 17:58:10 -0700
Subject: [PATCH] Minor LiteLLM Fixes and Improvements (#5456)

* fix(utils.py): support 'drop_params' for embedding requests

Fixes https://github.com/BerriAI/litellm/issues/5444

* feat(vertex_ai_non_gemini.py): support function param in messages

* test: skip test - model end of life

* fix(vertex_ai_non_gemini.py): fix gemini history parsing
---
 litellm/llms/prompt_templates/factory.py        | 112 ++++++++++-----
 .../vertex_and_google_ai_studio_gemini.py       |   2 +-
 .../vertex_ai_non_gemini.py                     |  59 ++++----
 .../tests/test_amazing_vertex_completion.py     | 128 +++++++++++++++++-
 litellm/types/llms/openai.py                    |  12 +-
 5 files changed, 248 insertions(+), 65 deletions(-)

diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 2396cd26c..a1894d87f 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -26,6 +26,12 @@ from litellm.types.completion import (
 )
 from litellm.types.llms.anthropic import *
 from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
+from litellm.types.llms.openai import (
+    ChatCompletionAssistantMessage,
+    ChatCompletionFunctionMessage,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionToolMessage,
+)
 from litellm.types.utils import GenericImageParsingChunk
 
 
@@ -964,8 +970,28 @@ def infer_protocol_value(
     return "unknown"
 
 
+def _gemini_tool_call_invoke_helper(
+    function_call_params: ChatCompletionToolCallFunctionChunk,
+) -> Optional[litellm.types.llms.vertex_ai.FunctionCall]:
+    name = function_call_params.get("name", "") or ""
+    arguments = function_call_params.get("arguments", "")
+    arguments_dict = json.loads(arguments)
+    function_call: Optional[litellm.types.llms.vertex_ai.FunctionCall] = None
+    for k, v in arguments_dict.items():
+        inferred_protocol_value = infer_protocol_value(value=v)
+        _field = litellm.types.llms.vertex_ai.Field(
+            key=k, value={inferred_protocol_value: v}
+        )
+        _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
+        function_call = litellm.types.llms.vertex_ai.FunctionCall(
+            name=name,
+            args=_fields,
+        )
+    return function_call
+
+
 def convert_to_gemini_tool_call_invoke(
-    tool_calls: list,
+    message: ChatCompletionAssistantMessage,
 ) -> List[litellm.types.llms.vertex_ai.PartType]:
     """
     OpenAI tool invokes:
@@ -1036,49 +1062,55 @@ def convert_to_gemini_tool_call_invoke(
     """
     try:
         _parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
-        for tool in tool_calls:
-            if "function" in tool:
-                name = tool["function"].get("name", "")
-                arguments = tool["function"].get("arguments", "")
-                arguments_dict = json.loads(arguments)
-                function_call: Optional[litellm.types.llms.vertex_ai.FunctionCall] = (
-                    None
+        tool_calls = message.get("tool_calls", None)
+        function_call = message.get("function_call", None)
+        if tool_calls is not None:
+            for tool in tool_calls:
+                if "function" in tool:
+                    gemini_function_call: Optional[
+                        litellm.types.llms.vertex_ai.FunctionCall
+                    ] = _gemini_tool_call_invoke_helper(
+                        function_call_params=tool["function"]
+                    )
+                    if gemini_function_call is not None:
+                        _parts_list.append(
+                            litellm.types.llms.vertex_ai.PartType(
+                                function_call=gemini_function_call
+                            )
+                        )
+                    else:  # don't silently drop params. Make it clear to user what's happening.
+                        raise Exception(
+                            "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
+                                tool
+                            )
+                        )
+        elif function_call is not None:
+            gemini_function_call = _gemini_tool_call_invoke_helper(
+                function_call_params=function_call
+            )
+            if gemini_function_call is not None:
+                _parts_list.append(
+                    litellm.types.llms.vertex_ai.PartType(
+                        function_call=gemini_function_call
+                    )
                 )
-                for k, v in arguments_dict.items():
-                    inferred_protocol_value = infer_protocol_value(value=v)
-                    _field = litellm.types.llms.vertex_ai.Field(
-                        key=k, value={inferred_protocol_value: v}
-                    )
-                    _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
-                        fields=_field
-                    )
-                    function_call = litellm.types.llms.vertex_ai.FunctionCall(
-                        name=name,
-                        args=_fields,
-                    )
-                if function_call is not None:
-                    _parts_list.append(
-                        litellm.types.llms.vertex_ai.PartType(
-                            function_call=function_call
-                        )
-                    )
-                else:  # don't silently drop params. Make it clear to user what's happening.
-                    raise Exception(
-                        "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
-                            tool
-                        )
+        else:  # don't silently drop params. Make it clear to user what's happening.
+            raise Exception(
+                "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
+                    tool
                 )
+            )
         return _parts_list
     except Exception as e:
         raise Exception(
             "Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
-                tool_calls, str(e)
+                message, str(e)
             )
         )
 
 
 def convert_to_gemini_tool_call_result(
-    message: dict,
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
     last_message_with_tool_calls: Optional[dict],
 ) -> litellm.types.llms.vertex_ai.PartType:
     """
@@ -1098,7 +1130,7 @@ def convert_to_gemini_tool_call_result(
     }
     """
     content = message.get("content", "")
-    name = ""
+    name: Optional[str] = message.get("name", "")  # type: ignore
 
     # Recover name from last message with tool calls
     if last_message_with_tool_calls:
@@ -1114,7 +1146,11 @@ def convert_to_gemini_tool_call_result(
                 name = tool.get("function", {}).get("name", "")
 
     if not name:
-        raise Exception("Missing corresponding tool call for tool response message")
+        raise Exception(
+            "Missing corresponding tool call for tool response message. Received - message={}, last_message_with_tool_calls={}".format(
+                message, last_message_with_tool_calls
+            )
+        )
 
     # We can't determine from openai message format whether it's a successful or
     # error call result so default to the successful result template
@@ -1127,7 +1163,7 @@ def convert_to_gemini_tool_call_result(
     _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
 
     _function_response = litellm.types.llms.vertex_ai.FunctionResponse(
-        name=name, response=_function_call_args
+        name=name, response=_function_call_args  # type: ignore
     )
 
     _part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
@@ -1782,7 +1818,9 @@ def cohere_messages_pt_v2(
         assistant_tool_calls: List[ToolCallObject] = []
         ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
         while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-            if messages[msg_i].get("content", None) is not None and isinstance(messages[msg_i]["content"], list):
+            if messages[msg_i].get("content", None) is not None and isinstance(
+                messages[msg_i]["content"], list
+            ):
                 for m in messages[msg_i]["content"]:
                     if m.get("type", "") == "text":
                         assistant_content += m["text"]
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 87ebdb56b..ab644485d 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -1433,7 +1433,7 @@ class VertexLLM(BaseLLM):
             },
         )
 
-        if stream is not None and stream is True:
+        if stream is True:
             request_data_str = json.dumps(data)
             streaming_response = CustomStreamWrapper(
                 completion_stream=None,
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
index 2a250864a..8fb7aa2ee 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
@@ -25,6 +25,7 @@ from litellm.types.files import (
     is_gemini_1_5_accepted_file_type,
     is_video_file_type,
 )
+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.llms.vertex_ai import *
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
@@ -123,7 +124,9 @@ def _process_gemini_image(image_url: str) -> PartType:
         raise e
 
 
-def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
+def _gemini_convert_messages_with_history(
+    messages: List[AllMessageValues],
+) -> List[ContentType]:
     """
     Converts given messages from OpenAI format to Gemini format
 
@@ -145,23 +148,26 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
         while (
             msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
         ):
-            if isinstance(messages[msg_i]["content"], list):
+            if messages[msg_i]["content"] is not None and isinstance(
+                messages[msg_i]["content"], list
+            ):
                 _parts: List[PartType] = []
-                for element in messages[msg_i]["content"]:
+                for element in messages[msg_i]["content"]:  # type: ignore
                     if isinstance(element, dict):
-                        if element["type"] == "text" and len(element["text"]) > 0:
-                            _part = PartType(text=element["text"])
+                        if element["type"] == "text" and len(element["text"]) > 0:  # type: ignore
+                            _part = PartType(text=element["text"])  # type: ignore
                             _parts.append(_part)
                         elif element["type"] == "image_url":
-                            image_url = element["image_url"]["url"]
+                            image_url = element["image_url"]["url"]  # type: ignore
                             _part = _process_gemini_image(image_url=image_url)
                             _parts.append(_part)  # type: ignore
                 user_content.extend(_parts)
             elif (
-                isinstance(messages[msg_i]["content"], str)
-                and len(messages[msg_i]["content"]) > 0
+                messages[msg_i]["content"] is not None
+                and isinstance(messages[msg_i]["content"], str)
+                and len(messages[msg_i]["content"]) > 0  # type: ignore
             ):
-                _part = PartType(text=messages[msg_i]["content"])
+                _part = PartType(text=messages[msg_i]["content"])  # type: ignore
                 user_content.append(_part)
 
             msg_i += 1
@@ -175,31 +181,34 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
                 messages[msg_i]["content"], list
             ):
                 _parts = []
-                for element in messages[msg_i]["content"]:
+                for element in messages[msg_i]["content"]:  # type: ignore
                     if isinstance(element, dict):
                         if element["type"] == "text":
-                            _part = PartType(text=element["text"])
+                            _part = PartType(text=element["text"])  # type: ignore
                             _parts.append(_part)
                         elif element["type"] == "image_url":
-                            image_url = element["image_url"]["url"]
+                            image_url = element["image_url"]["url"]  # type: ignore
                             _part = _process_gemini_image(image_url=image_url)
                             _parts.append(_part)  # type: ignore
                 assistant_content.extend(_parts)
+            elif (
+                messages[msg_i].get("content", None) is not None
+                and isinstance(messages[msg_i]["content"], str)
+                and messages[msg_i]["content"]
+            ):
+                assistant_text = messages[msg_i]["content"]  # either string or none
+                assistant_content.append(PartType(text=assistant_text))  # type: ignore
             elif messages[msg_i].get(
                 "tool_calls", []
             ):  # support assistant tool invoke conversion
                 assistant_content.extend(
-                    convert_to_gemini_tool_call_invoke(
-                        messages[msg_i]["tool_calls"]
-                    )
+                    convert_to_gemini_tool_call_invoke(messages[msg_i])  # type: ignore
                 )
                 last_message_with_tool_calls = messages[msg_i]
-            else:
-                assistant_text = (
-                    messages[msg_i].get("content") or ""
-                )  # either string or none
-                if assistant_text:
-                    assistant_content.append(PartType(text=assistant_text))
+            elif messages[msg_i].get("function_call") is not None:
+                assistant_content.extend(
+                    convert_to_gemini_tool_call_invoke(messages[msg_i])  # type: ignore
+                )
 
             msg_i += 1
 
@@ -207,12 +216,16 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
             contents.append(ContentType(role="model", parts=assistant_content))
 
         ## APPEND TOOL CALL MESSAGES ##
-        if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+        if msg_i < len(messages) and (
+            messages[msg_i]["role"] == "tool"
+            or messages[msg_i]["role"] == "function"
+        ):
             _part = convert_to_gemini_tool_call_result(
-                messages[msg_i], last_message_with_tool_calls
+                messages[msg_i], last_message_with_tool_calls  # type: ignore
             )
             contents.append(ContentType(parts=[_part]))  # type: ignore
             msg_i += 1
+
         if msg_i == init_msg_i:  # prevent infinite loops
             raise Exception(
                 "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 4b659944a..5ff1e1046 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -906,6 +906,7 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
         "tools": tools,
         "tool_choice": "required",
     }
+    print(f"Model for call - {model}")
     if sync_mode:
         response = litellm.completion(**data)
     else:
@@ -2630,6 +2631,129 @@ async def test_partner_models_httpx_ai21():
 
     print(f"response: {response}")
 
-    print("hidden params from response=", response._hidden_params)
-    assert response._hidden_params["response_cost"] > 0
+
+def test_gemini_function_call_parameter_in_messages():
+    litellm.set_verbose = True
+    load_vertex_ai_credentials()
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "search",
+                "description": "Executes searches.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "queries": {
+                            "type": "array",
+                            "description": "A list of queries to search for.",
+                            "items": {"type": "string"},
+                        },
+                    },
+                    "required": ["queries"],
+                },
+            },
+        },
+    ]
+
+    # Set up the messages
+    messages = [
+        {"role": "system", "content": """Use search for most queries."""},
+        {"role": "user", "content": """search for weather in boston (use `search`)"""},
+        {
+            "role": "assistant",
+            "content": None,
+            "function_call": {
+                "name": "search",
+                "arguments": '{"queries": ["weather in boston"]}',
+            },
+        },
+        {
+            "role": "function",
+            "name": "search",
+            "content": "The current weather in Boston is 22°F.",
+        },
+    ]
+
+    client = HTTPHandler(concurrent_limit=1)
+
+    with patch.object(client, "post", new=MagicMock()) as mock_client:
+        try:
+            response_stream = completion(
+                model="vertex_ai/gemini-1.5-pro",
+                messages=messages,
+                tools=tools,
+                tool_choice="auto",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        # mock_client.assert_any_call()
+        assert {
+            "contents": [
+                {
+                    "role": "user",
+                    "parts": [{"text": "search for weather in boston (use `search`)"}],
+                },
+                {
+                    "role": "model",
+                    "parts": [
+                        {
+                            "function_call": {
+                                "name": "search",
+                                "args": {
+                                    "fields": {
+                                        "key": "queries",
+                                        "value": {"list_value": ["weather in boston"]},
+                                    }
+                                },
+                            }
+                        }
+                    ],
+                },
+                {
+                    "parts": [
+                        {
+                            "function_response": {
+                                "name": "search",
+                                "response": {
+                                    "fields": {
+                                        "key": "content",
+                                        "value": {
+                                            "string_value": "The current weather in Boston is 22°F."
+                                        },
+                                    }
+                                },
+                            }
+                        }
+                    ]
+                },
+            ],
+            "system_instruction": {"parts": [{"text": "Use search for most queries."}]},
+            "tools": [
+                {
+                    "function_declarations": [
+                        {
+                            "name": "search",
+                            "description": "Executes searches.",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "queries": {
+                                        "type": "array",
+                                        "description": "A list of queries to search for.",
+                                        "items": {"type": "string"},
+                                    }
+                                },
+                                "required": ["queries"],
+                            },
+                        }
+                    ]
+                }
+            ],
+            "toolConfig": {"functionCallingConfig": {"mode": "AUTO"}},
+            "generationConfig": {},
+        } == mock_client.call_args.kwargs["json"]
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 138441a7e..a3d8b756f 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -364,8 +364,9 @@ class ChatCompletionUserMessage(TypedDict):
 class ChatCompletionAssistantMessage(TypedDict, total=False):
     role: Required[Literal["assistant"]]
     content: Optional[str]
-    name: str
-    tool_calls: List[ChatCompletionAssistantToolCall]
+    name: Optional[str]
+    tool_calls: Optional[List[ChatCompletionAssistantToolCall]]
+    function_call: Optional[ChatCompletionToolCallFunctionChunk]
 
 
 class ChatCompletionToolMessage(TypedDict):
@@ -374,6 +375,12 @@ class ChatCompletionToolMessage(TypedDict):
     tool_call_id: str
 
 
+class ChatCompletionFunctionMessage(TypedDict):
+    role: Literal["function"]
+    content: Optional[str]
+    name: str
+
+
 class ChatCompletionSystemMessage(TypedDict, total=False):
     role: Required[Literal["system"]]
     content: Required[Union[str, List]]
@@ -385,6 +392,7 @@ AllMessageValues = Union[
     ChatCompletionAssistantMessage,
     ChatCompletionToolMessage,
     ChatCompletionSystemMessage,
+    ChatCompletionFunctionMessage,
 ]
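
Usage notes (illustrative sketches, not part of the diff above):

With the `function_call` / `function`-role support added here, a legacy OpenAI
function-calling history should now convert into Gemini `function_call` and
`function_response` parts, mirroring what
`test_gemini_function_call_parameter_in_messages` asserts. A minimal sketch
(tool definitions omitted for brevity; the test above also passes `tools`):

    import litellm

    # Assistant turn carrying the legacy `function_call` field, followed by a
    # "function"-role result message - both are now handled by
    # _gemini_convert_messages_with_history().
    messages = [
        {"role": "user", "content": "search for weather in boston (use `search`)"},
        {
            "role": "assistant",
            "content": None,
            "function_call": {
                "name": "search",
                "arguments": '{"queries": ["weather in boston"]}',
            },
        },
        {
            "role": "function",
            "name": "search",
            "content": "The current weather in Boston is 22°F.",
        },
    ]

    response = litellm.completion(
        model="vertex_ai/gemini-1.5-pro",
        messages=messages,
    )

For the first bullet in the commit message (whose utils.py hunk is not shown in
this patch excerpt), the intended call pattern is roughly the following; the
`user` parameter is only an illustration, since the exact set of unsupported
params depends on the provider:

    import litellm

    # With drop_params=True, OpenAI-specific params the provider does not
    # support are dropped from the embedding request instead of raising
    # an error (see issue #5444).
    litellm.embedding(
        model="vertex_ai/textembedding-gecko",
        input=["hello world"],
        user="user-123",  # illustrative param a provider may not accept
        drop_params=True,
    )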