fix(vertex_ai.py): support passing in result of tool call to vertex

Fixes https://github.com/BerriAI/litellm/issues/3709
Krrish Dholakia 2024-05-19 11:34:07 -07:00
parent 5d3fe52a08
commit a2c66ed4fb
5 changed files with 485 additions and 46 deletions
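As context for the change: with this fix, an OpenAI-style tool-result message can be passed straight through litellm to a Vertex AI Gemini model. A minimal sketch of the call pattern (message shapes mirror the test added below; the model name is illustrative, not prescribed by this commit):

import litellm

# Conversation that already contains an assistant tool call and its result,
# in OpenAI format; the vertex_ai handler now converts the "tool" message
# into a Gemini function_response part.
messages = [
    {"role": "user", "content": "What's the weather in San Francisco?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"location": "San Francisco, CA"}',
                },
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_123",
        "name": "get_weather",
        "content": "27 degrees celsius and clear in San Francisco, CA",
    },
]

response = litellm.completion(
    model="vertex_ai/gemini-1.5-pro-preview-0514",  # illustrative model name
    messages=messages,
)
print(response.choices[0].message.content)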

litellm/llms/prompt_templates/factory.py

@@ -12,6 +12,7 @@ from typing import (
    Sequence,
)
import litellm
import litellm.types
from litellm.types.completion import (
    ChatCompletionUserMessageParam,
    ChatCompletionSystemMessageParam,
@@ -20,9 +21,12 @@ from litellm.types.completion import (
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
import litellm.types.llms.vertex_ai


def default_pt(messages):
    return " ".join(message["content"] for message in messages)
@@ -841,6 +845,175 @@ def anthropic_messages_pt_xml(messages: list):

# ------------------------------------------------------------------------------

def infer_protocol_value(
    value: Any,
) -> Literal[
    "string_value",
    "number_value",
    "bool_value",
    "struct_value",
    "list_value",
    "null_value",
    "unknown",
]:
    if value is None:
        return "null_value"
    if isinstance(value, int) or isinstance(value, float):
        return "number_value"
    if isinstance(value, str):
        return "string_value"
    if isinstance(value, bool):
        return "bool_value"
    if isinstance(value, dict):
        return "struct_value"
    if isinstance(value, list):
        return "list_value"
    return "unknown"


def convert_to_gemini_tool_call_invoke(
    tool_calls: list,
) -> List[litellm.types.llms.vertex_ai.PartType]:
    """
    OpenAI tool invokes:
    {
      "role": "assistant",
      "content": null,
      "tool_calls": [
        {
          "id": "call_abc123",
          "type": "function",
          "function": {
            "name": "get_current_weather",
            "arguments": "{\n\"location\": \"Boston, MA\"\n}"
          }
        }
      ]
    },
    """
    """
    Gemini tool call invokes: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#submit-api-output
    content {
      role: "model"
      parts [
        {
          function_call {
            name: "get_current_weather"
            args {
              fields {
                key: "unit"
                value {
                  string_value: "fahrenheit"
                }
              }
              fields {
                key: "predicted_temperature"
                value {
                  number_value: 45
                }
              }
              fields {
                key: "location"
                value {
                  string_value: "Boston, MA"
                }
              }
            }
          }
        },
        {
          function_call {
            name: "get_current_weather"
            args {
              fields {
                key: "location"
                value {
                  string_value: "San Francisco"
                }
              }
            }
          }
        }
      ]
    }
    """
    """
    - json.load the arguments
    - iterate through arguments -> create a FunctionCallArgs for each field
    """
    try:
        _parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
        for tool in tool_calls:
            if "function" in tool:
                name = tool["function"].get("name", "")
                arguments = tool["function"].get("arguments", "")
                arguments_dict = json.loads(arguments)
                for k, v in arguments_dict.items():
                    inferred_protocol_value = infer_protocol_value(value=v)
                    _field = litellm.types.llms.vertex_ai.Field(
                        key=k, value={inferred_protocol_value: v}
                    )
                    _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
                        fields=_field
                    )
                function_call = litellm.types.llms.vertex_ai.FunctionCall(
                    name=name,
                    args=_fields,
                )
                _parts_list.append(
                    litellm.types.llms.vertex_ai.PartType(function_call=function_call)
                )
        return _parts_list
    except Exception as e:
        raise Exception(
            "Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
                tool_calls, str(e)
            )
        )


def convert_to_gemini_tool_call_result(
    message: dict,
) -> litellm.types.llms.vertex_ai.PartType:
    """
    OpenAI message with a tool result looks like:
    {
        "tool_call_id": "tool_1",
        "role": "tool",
        "name": "get_current_weather",
        "content": "function result goes here",
    },

    OpenAI message with a function call result looks like:
    {
        "role": "function",
        "name": "get_current_weather",
        "content": "function result goes here",
    }
    """
    content = message.get("content", "")
    name = message.get("name", "")

    # We can't determine from openai message format whether it's a successful or
    # error call result so default to the successful result template
    inferred_content_value = infer_protocol_value(value=content)

    _field = litellm.types.llms.vertex_ai.Field(
        key="content", value={inferred_content_value: content}
    )

    _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)

    _function_response = litellm.types.llms.vertex_ai.FunctionResponse(
        name=name, response=_function_call_args
    )

    _part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)

    return _part


def convert_to_anthropic_tool_result(message: dict) -> dict:
    """
    OpenAI message with a tool result looks like:
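For orientation, a rough sketch of what the new convert_to_gemini_tool_call_result helper is expected to return for the tool message shown in its docstring; the expected shape is inferred from the code above, not taken from the commit (TypedDicts are plain dicts at runtime):

from litellm.llms.prompt_templates.factory import convert_to_gemini_tool_call_result

tool_message = {
    "tool_call_id": "tool_1",
    "role": "tool",
    "name": "get_current_weather",
    "content": "function result goes here",
}

part = convert_to_gemini_tool_call_result(tool_message)
# Expected shape, roughly:
# {
#     "function_response": {
#         "name": "get_current_weather",
#         "response": {
#             "fields": {
#                 "key": "content",
#                 "value": {"string_value": "function result goes here"},
#             }
#         },
#     }
# }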

litellm/llms/vertex_ai.py

@@ -3,10 +3,15 @@ import json
from enum import Enum
import requests  # type: ignore
import time
-from typing import Callable, Optional, Union, List
+from typing import Callable, Optional, Union, List, Literal
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect  # type: ignore
from litellm.types.llms.vertex_ai import *
from litellm.llms.prompt_templates.factory import (
    convert_to_gemini_tool_call_result,
    convert_to_gemini_tool_call_invoke,
)


class VertexAIError(Exception):
@@ -283,6 +288,129 @@ def _load_image_from_url(image_url: str):
    return Image.from_bytes(data=image_bytes)

def _convert_gemini_role(role: str) -> Literal["user", "model"]:
    if role == "user":
        return "user"
    else:
        return "model"


def _process_gemini_image(image_url: str):
    try:
        import vertexai
    except:
        raise VertexAIError(
            status_code=400,
            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
        )

    from vertexai.preview.generative_models import Part

    if "gs://" in image_url:
        # Case 1: Images with Cloud Storage URIs
        # The supported MIME types for images include image/png and image/jpeg.
        part_mime = "image/png" if "png" in image_url else "image/jpeg"
        google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
        return google_clooud_part
    elif "https:/" in image_url:
        # Case 2: Images with direct links
        image = _load_image_from_url(image_url)
        return image
    elif ".mp4" in image_url and "gs://" in image_url:
        # Case 3: Videos with Cloud Storage URIs
        part_mime = "video/mp4"
        google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
        return google_clooud_part
    elif "base64" in image_url:
        # Case 4: Images with base64 encoding
        import base64, re

        # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
        image_metadata, img_without_base_64 = image_url.split(",")

        # read mime_type from img_without_base_64=data:image/jpeg;base64
        # Extract MIME type using regular expression
        mime_type_match = re.match(r"data:(.*?);base64", image_metadata)

        if mime_type_match:
            mime_type = mime_type_match.group(1)
        else:
            mime_type = "image/jpeg"
        decoded_img = base64.b64decode(img_without_base_64)
        processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
        return processed_image


def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
    """
    Converts given messages from OpenAI format to Gemini format

    - Parts must be iterable
    - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
    - Please ensure that function response turn comes immediately after a function call turn
    """
    user_message_types = {"user", "system"}
    contents: List[ContentType] = []

    msg_i = 0
    while msg_i < len(messages):
        user_content: List[PartType] = []
        init_msg_i = msg_i
        ## MERGE CONSECUTIVE USER CONTENT ##
        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
            if isinstance(messages[msg_i]["content"], list):
                _parts: List[PartType] = []
                for element in messages[msg_i]["content"]:
                    if isinstance(element, dict):
                        if element["type"] == "text":
                            _part = PartType(text=element["text"])
                            _parts.append(_part)
                        elif element["type"] == "image_url":
                            image_url = element["image_url"]["url"]
                            _part = _process_gemini_image(image_url=image_url)
                            _parts.append(_part)  # type: ignore
                user_content.extend(_parts)
            else:
                _part = PartType(text=messages[msg_i]["content"])
                user_content.append(_part)

            msg_i += 1

        if user_content:
            contents.append(ContentType(role="user", parts=user_content))

        assistant_content = []
        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
            assistant_text = (
                messages[msg_i].get("content") or ""
            )  # either string or none
            if assistant_text:
                assistant_content.append(PartType(text=assistant_text))
            if messages[msg_i].get(
                "tool_calls", []
            ):  # support assistant tool invoke convertion
                assistant_content.extend(
                    convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
                )

            msg_i += 1

        if assistant_content:
            contents.append(ContentType(role="model", parts=assistant_content))

        ## APPEND TOOL CALL MESSAGES ##
        if messages[msg_i]["role"] == "tool":
            _part = convert_to_gemini_tool_call_result(messages[msg_i])
            contents.append(ContentType(parts=[_part]))  # type: ignore
            msg_i += 1

        if msg_i == init_msg_i:  # prevent infinite loops
            raise Exception(
                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
                    messages[msg_i]
                )
            )

    return contents


def _gemini_vision_convert_messages(messages: list):
    """
    Converts given messages for GPT-4 Vision to Gemini format.
@@ -574,11 +702,9 @@ def completion(
            print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
            print_verbose(f"\nProcessing input messages = {messages}")
            tools = optional_params.pop("tools", None)
-            prompt, images = _gemini_vision_convert_messages(messages=messages)
-            content = [prompt] + images
+            content = _gemini_convert_messages_text(messages=messages)
            stream = optional_params.pop("stream", False)
            if stream == True:
                request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                logging_obj.pre_call(
                    input=prompt,
@@ -590,7 +716,7 @@ def completion(
                )

                model_response = llm_model.generate_content(
-                    contents=content,
+                    contents={"content": content},
                    generation_config=optional_params,
                    safety_settings=safety_settings,
                    stream=True,
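To make the conversion concrete, here is a rough sketch, inferred from the code above rather than taken from the commit, of the grouping _gemini_convert_messages_text aims for with the tool-calling messages used in the tests below: consecutive user/system messages are merged into one "user" turn, an assistant message with tool_calls becomes a "model" turn holding function_call parts, and the tool result becomes a function_response part.

# Sketch of the expected output structure (dict shapes per litellm/types/llms/vertex_ai.py).
expected_contents = [
    {
        "role": "user",
        "parts": [
            {"text": "Your name is Litellm Bot, you are a helpful assistant"},
            {"text": "Hello, what is your name and can you tell me the weather?"},
        ],
    },
    {
        "role": "model",
        "parts": [
            {
                "function_call": {
                    "name": "get_weather",
                    "args": {
                        "fields": {
                            "key": "location",
                            "value": {"string_value": "San Francisco, CA"},
                        }
                    },
                }
            }
        ],
    },
    {
        "parts": [
            {
                "function_response": {
                    "name": "get_weather",
                    "response": {
                        "fields": {
                            "key": "content",
                            "value": {
                                "string_value": "27 degrees celsius and clear in San Francisco, CA"
                            },
                        }
                    },
                }
            }
        ]
    },
]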


@@ -1,15 +1,21 @@
============================= test session starts ==============================
-platform darwin -- Python 3.11.9, pytest-7.3.1, pluggy-1.3.0
+platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0
-rootdir: /Users/krrishdholakia/Documents/litellm/litellm/tests
+rootdir: /Users/krrishdholakia/Documents/litellm
-plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1
+configfile: pyproject.toml
+plugins: asyncio-0.23.6, mock-3.14.0, anyio-4.2.0
asyncio: mode=Mode.STRICT
collected 1 item

-test_router_timeout.py . [100%]
+test_amazing_vertex_completion.py Success: model=gemini-1.5-pro-preview-0514 in model_cost_map
+prompt_tokens=88; completion_tokens=31
+Returned custom cost for model=gemini-1.5-pro-preview-0514 - prompt_tokens_cost_usd_dollar: 5.5e-05, completion_tokens_cost_usd_dollar: 5.8125e-05
+final cost: 0.000113125; prompt_tokens_cost_usd_dollar: 5.5e-05; completion_tokens_cost_usd_dollar: 5.8125e-05
+success callbacks: []
+. [100%]

=============================== warnings summary ===============================
-../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings
+../proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings
-  /opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../proxy/_types.py:255
@@ -68,9 +74,5 @@ test_router_timeout.py . [100%]
  /Users/krrishdholakia/Documents/litellm/litellm/utils.py:60: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
    with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f:
-test_router_timeout.py::test_router_timeouts_bedrock
-  /opt/homebrew/lib/python3.11/site-packages/httpx/_content.py:204: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
-    warnings.warn(message, DeprecationWarning)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-======================== 1 passed, 40 warnings in 0.99s ========================
+======================== 1 passed, 39 warnings in 2.92s ========================

litellm/tests/test_amazing_vertex_completion.py

@@ -16,6 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests
import json
import os
import tempfile
from litellm.llms.vertex_ai import _gemini_convert_messages_text

litellm.num_retries = 3
litellm.cache = None
@@ -98,7 +99,7 @@ def load_vertex_ai_credentials():
@pytest.mark.asyncio
-async def get_response():
+async def test_get_response():
    load_vertex_ai_credentials()
    prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
    try:
@@ -589,35 +590,73 @@ def test_gemini_pro_vision_base64():
        pytest.fail(f"An exception occurred - {str(e)}")


-@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("sync_mode", [True])
@pytest.mark.asyncio
async def test_gemini_pro_function_calling(sync_mode):
    try:
        load_vertex_ai_credentials()
-        data = {
-            "model": "vertex_ai/gemini-pro",
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Call the submit_cities function with San Francisco and New York",
-                }
-            ],
-            "tools": [
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "submit_cities",
-                        "description": "Submits a list of cities",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {
-                                "cities": {"type": "array", "items": {"type": "string"}}
-                            },
-                            "required": ["cities"],
-                        },
-                    },
-                }
-            ],
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+            # Assistant replies with a tool call
+            {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "id": "call_123",
+                        "type": "function",
+                        "index": 0,
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": '{"location":"San Francisco, CA"}',
+                        },
+                    }
+                ],
+            },
+            # The result of the tool call is added to the history
+            {
+                "role": "tool",
+                "tool_call_id": "call_123",
+                "name": "get_weather",
+                "content": "27 degrees celsius and clear in San Francisco, CA",
+            },
+            # Now the assistant can reply with the result of the tool call.
+        ]
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            }
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+
+        data = {
+            "model": "vertex_ai/gemini-1.5-pro-preview-0514",
+            "messages": messages,
+            "tools": tools,
        }
        if sync_mode:
            response = litellm.completion(**data)
@@ -712,7 +751,7 @@ async def test_gemini_pro_async_function_calling():
            "type": "function",
            "function": {
                "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
+                "description": "Get the current weather in a given location. Response with a prediction on temperature as well.",
                "parameters": {
                    "type": "object",
                    "properties": {
@@ -724,8 +763,9 @@ async def test_gemini_pro_async_function_calling():
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                        "predicted_temperature": {"type": "integer"},
                    },
-                    "required": ["location"],
+                    "required": ["location", "predicted_temperature"],
                },
            },
        }
@@ -733,7 +773,7 @@ async def test_gemini_pro_async_function_calling():
    messages = [
        {
            "role": "user",
-            "content": "What's the weather like in Boston today in fahrenheit?",
+            "content": "What's the weather like in Boston today in fahrenheit, give me a prediction?",
        }
    ]
    completion = await litellm.acompletion(
@@ -742,8 +782,10 @@ async def test_gemini_pro_async_function_calling():
        print(f"completion: {completion}")
        assert completion.choices[0].message.content is None
        assert len(completion.choices[0].message.tool_calls) == 1

-    except litellm.APIError as e:
-        pass
+        raise Exception
+    # except litellm.APIError as e:
+    #     pass
    except litellm.RateLimitError as e:
        pass
    except Exception as e:
@@ -893,3 +935,46 @@ async def test_vertexai_aembedding():
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_async()


def test_prompt_factory():
    messages = [
        {
            "role": "system",
            "content": "Your name is Litellm Bot, you are a helpful assistant",
        },
        # User asks for their name and weather in San Francisco
        {
            "role": "user",
            "content": "Hello, what is your name and can you tell me the weather?",
        },
        # Assistant replies with a tool call
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "index": 0,
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location":"San Francisco, CA"}',
                    },
                }
            ],
        },
        # The result of the tool call is added to the history
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "get_weather",
            "content": "27 degrees celsius and clear in San Francisco, CA",
        },
        # Now the assistant can reply with the result of the tool call.
    ]

    translated_messages = _gemini_convert_messages_text(messages=messages)

    print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")

    raise Exception

litellm/types/llms/vertex_ai.py

@@ -0,0 +1,53 @@
from typing import TypedDict, Any, Union, Optional, List, Literal
import json
from typing_extensions import (
    Self,
    Protocol,
    TypeGuard,
    override,
    get_origin,
    runtime_checkable,
    Required,
)


class Field(TypedDict):
    key: str
    value: dict[str, Any]


class FunctionCallArgs(TypedDict):
    fields: Field


class FunctionResponse(TypedDict):
    name: str
    response: FunctionCallArgs


class FunctionCall(TypedDict):
    name: str
    args: FunctionCallArgs


class FileDataType(TypedDict):
    mime_type: str
    file_uri: str  # the cloud storage uri of storing this file


class BlobType(TypedDict):
    mime_type: Required[str]
    data: Required[bytes]


class PartType(TypedDict, total=False):
    text: str
    inline_data: BlobType
    file_data: FileDataType
    function_call: FunctionCall
    function_response: FunctionResponse


class ContentType(TypedDict, total=False):
    role: Literal["user", "model"]
    parts: Required[List[PartType]]
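As a quick illustration of how these shapes compose (assumed usage, not part of the commit; the gs:// URI is made up), a text part plus a Cloud Storage video part for a user turn could be built like this:

from litellm.types.llms.vertex_ai import ContentType, PartType

text_part: PartType = {"text": "What is happening in this video?"}
video_part: PartType = {
    "file_data": {"mime_type": "video/mp4", "file_uri": "gs://my-bucket/clip.mp4"},
}
content: ContentType = {"role": "user", "parts": [text_part, video_part]}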