mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(vertex_ai.py): support passing in result of tool call to vertex
Fixes https://github.com/BerriAI/litellm/issues/3709
This commit is contained in:
parent 5d3fe52a08
commit a2c66ed4fb
5 changed files with 485 additions and 46 deletions
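
For context, the end-to-end flow this commit enables mirrors the updated test in test_amazing_vertex_completion.py: an OpenAI-style "tool" message carrying a function result is sent back to a Vertex AI Gemini model through litellm.completion. A minimal sketch, assuming a configured Vertex AI environment; model name, tool schema and weather strings are taken from that test and are illustrative only:

    # Sketch of the flow this change enables: an OpenAI-style tool-result message
    # is sent back to a Vertex AI Gemini model via litellm. Values are illustrative.
    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]

    messages = [
        {"role": "user", "content": "What's the weather in San Francisco?"},
        # Assistant turn that invoked the tool
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "San Francisco, CA"}',
                    },
                }
            ],
        },
        # The tool result message that this commit teaches the vertex_ai handler to translate
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "get_weather",
            "content": "27 degrees celsius and clear in San Francisco, CA",
        },
    ]

    response = litellm.completion(
        model="vertex_ai/gemini-1.5-pro-preview-0514",
        messages=messages,
        tools=tools,
    )
    print(response.choices[0].message.content)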

litellm/llms/prompt_templates/factory.py

@@ -12,6 +12,7 @@ from typing import (
     Sequence,
 )
 import litellm
+import litellm.types
 from litellm.types.completion import (
     ChatCompletionUserMessageParam,
     ChatCompletionSystemMessageParam,
@@ -20,9 +21,12 @@ from litellm.types.completion import (
     ChatCompletionMessageToolCallParam,
     ChatCompletionToolMessageParam,
 )
+import litellm.types.llms
 from litellm.types.llms.anthropic import *
 import uuid
 
+import litellm.types.llms.vertex_ai
+
 
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -841,6 +845,175 @@ def anthropic_messages_pt_xml(messages: list):
 # ------------------------------------------------------------------------------
 
 
+def infer_protocol_value(
+    value: Any,
+) -> Literal[
+    "string_value",
+    "number_value",
+    "bool_value",
+    "struct_value",
+    "list_value",
+    "null_value",
+    "unknown",
+]:
+    if value is None:
+        return "null_value"
+    if isinstance(value, int) or isinstance(value, float):
+        return "number_value"
+    if isinstance(value, str):
+        return "string_value"
+    if isinstance(value, bool):
+        return "bool_value"
+    if isinstance(value, dict):
+        return "struct_value"
+    if isinstance(value, list):
+        return "list_value"
+
+    return "unknown"
+
+
+def convert_to_gemini_tool_call_invoke(
+    tool_calls: list,
+) -> List[litellm.types.llms.vertex_ai.PartType]:
+    """
+    OpenAI tool invokes:
+    {
+      "role": "assistant",
+      "content": null,
+      "tool_calls": [
+        {
+          "id": "call_abc123",
+          "type": "function",
+          "function": {
+            "name": "get_current_weather",
+            "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+          }
+        }
+      ]
+    },
+    """
+    """
+    Gemini tool call invokes: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#submit-api-output
+    content {
+      role: "model"
+      parts [
+        {
+          function_call {
+            name: "get_current_weather"
+            args {
+              fields {
+                key: "unit"
+                value {
+                  string_value: "fahrenheit"
+                }
+              }
+              fields {
+                key: "predicted_temperature"
+                value {
+                  number_value: 45
+                }
+              }
+              fields {
+                key: "location"
+                value {
+                  string_value: "Boston, MA"
+                }
+              }
+            }
+          }
+        },
+        {
+          function_call {
+            name: "get_current_weather"
+            args {
+              fields {
+                key: "location"
+                value {
+                  string_value: "San Francisco"
+                }
+              }
+            }
+          }
+        }
+      ]
+    }
+    """
+
+    """
+    - json.load the arguments
+    - iterate through arguments -> create a FunctionCallArgs for each field
+    """
+    try:
+        _parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
+        for tool in tool_calls:
+            if "function" in tool:
+                name = tool["function"].get("name", "")
+                arguments = tool["function"].get("arguments", "")
+                arguments_dict = json.loads(arguments)
+                for k, v in arguments_dict.items():
+                    inferred_protocol_value = infer_protocol_value(value=v)
+                    _field = litellm.types.llms.vertex_ai.Field(
+                        key=k, value={inferred_protocol_value: v}
+                    )
+                    _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
+                        fields=_field
+                    )
+                    function_call = litellm.types.llms.vertex_ai.FunctionCall(
+                        name=name,
+                        args=_fields,
+                    )
+                    _parts_list.append(
+                        litellm.types.llms.vertex_ai.PartType(function_call=function_call)
+                    )
+        return _parts_list
+    except Exception as e:
+        raise Exception(
+            "Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
+                tool_calls, str(e)
+            )
+        )
+
+
+def convert_to_gemini_tool_call_result(
+    message: dict,
+) -> litellm.types.llms.vertex_ai.PartType:
+    """
+    OpenAI message with a tool result looks like:
+    {
+        "tool_call_id": "tool_1",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    },
+
+    OpenAI message with a function call result looks like:
+    {
+        "role": "function",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    }
+    """
+    content = message.get("content", "")
+    name = message.get("name", "")
+
+    # We can't determine from openai message format whether it's a successful or
+    # error call result so default to the successful result template
+    inferred_content_value = infer_protocol_value(value=content)
+
+    _field = litellm.types.llms.vertex_ai.Field(
+        key="content", value={inferred_content_value: content}
+    )
+
+    _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
+
+    _function_response = litellm.types.llms.vertex_ai.FunctionResponse(
+        name=name, response=_function_call_args
+    )
+
+    _part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
+
+    return _part
+
+
 def convert_to_anthropic_tool_result(message: dict) -> dict:
     """
     OpenAI message with a tool result looks like:
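
A rough sketch of exercising the two new factory helpers directly; the input dicts follow the OpenAI shapes quoted in the docstrings above, and the commented output shapes are read off the function bodies rather than captured from a real run:

    # Assistant tool invocation -> Gemini "function_call" parts;
    # tool result message -> a single Gemini "function_response" part.
    from litellm.llms.prompt_templates.factory import (
        convert_to_gemini_tool_call_invoke,
        convert_to_gemini_tool_call_result,
    )

    parts = convert_to_gemini_tool_call_invoke(
        [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Boston, MA"}',
                },
            }
        ]
    )
    # parts ~= [{"function_call": {"name": "get_current_weather",
    #            "args": {"fields": {"key": "location",
    #                                "value": {"string_value": "Boston, MA"}}}}}]

    result_part = convert_to_gemini_tool_call_result(
        {
            "tool_call_id": "tool_1",
            "role": "tool",
            "name": "get_current_weather",
            "content": "72 and sunny",
        }
    )
    # result_part ~= {"function_response": {"name": "get_current_weather",
    #                 "response": {"fields": {"key": "content",
    #                              "value": {"string_value": "72 and sunny"}}}}}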
litellm/llms/vertex_ai.py

@@ -3,10 +3,15 @@ import json
 from enum import Enum
 import requests # type: ignore
 import time
-from typing import Callable, Optional, Union, List
+from typing import Callable, Optional, Union, List, Literal
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx, inspect # type: ignore
+from litellm.types.llms.vertex_ai import *
+from litellm.llms.prompt_templates.factory import (
+    convert_to_gemini_tool_call_result,
+    convert_to_gemini_tool_call_invoke,
+)
 
 
 class VertexAIError(Exception):
@@ -283,6 +288,129 @@ def _load_image_from_url(image_url: str):
     return Image.from_bytes(data=image_bytes)
 
 
+def _convert_gemini_role(role: str) -> Literal["user", "model"]:
+    if role == "user":
+        return "user"
+    else:
+        return "model"
+
+
+def _process_gemini_image(image_url: str):
+    try:
+        import vertexai
+    except:
+        raise VertexAIError(
+            status_code=400,
+            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
+        )
+    from vertexai.preview.generative_models import Part
+
+    if "gs://" in image_url:
+        # Case 1: Images with Cloud Storage URIs
+        # The supported MIME types for images include image/png and image/jpeg.
+        part_mime = "image/png" if "png" in image_url else "image/jpeg"
+        google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
+        return google_clooud_part
+    elif "https:/" in image_url:
+        # Case 2: Images with direct links
+        image = _load_image_from_url(image_url)
+        return image
+    elif ".mp4" in image_url and "gs://" in image_url:
+        # Case 3: Videos with Cloud Storage URIs
+        part_mime = "video/mp4"
+        google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
+        return google_clooud_part
+    elif "base64" in image_url:
+        # Case 4: Images with base64 encoding
+        import base64, re
+
+        # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
+        image_metadata, img_without_base_64 = image_url.split(",")
+
+        # read mime_type from img_without_base_64=data:image/jpeg;base64
+        # Extract MIME type using regular expression
+        mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
+
+        if mime_type_match:
+            mime_type = mime_type_match.group(1)
+        else:
+            mime_type = "image/jpeg"
+        decoded_img = base64.b64decode(img_without_base_64)
+        processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
+        return processed_image
+
+
+def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
+    """
+    Converts given messages from OpenAI format to Gemini format
+
+    - Parts must be iterable
+    - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+    - Please ensure that function response turn comes immediately after a function call turn
+    """
+    user_message_types = {"user", "system"}
+    contents: List[ContentType] = []
+
+    msg_i = 0
+    while msg_i < len(messages):
+        user_content: List[PartType] = []
+        init_msg_i = msg_i
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+            if isinstance(messages[msg_i]["content"], list):
+                _parts: List[PartType] = []
+                for element in messages[msg_i]["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            _part = PartType(text=element["text"])
+                            _parts.append(_part)
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            _part = _process_gemini_image(image_url=image_url)
+                            _parts.append(_part) # type: ignore
+                user_content.extend(_parts)
+            else:
+                _part = PartType(text=messages[msg_i]["content"])
+                user_content.append(_part)
+
+            msg_i += 1
+
+        if user_content:
+            contents.append(ContentType(role="user", parts=user_content))
+        assistant_content = []
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            assistant_text = (
+                messages[msg_i].get("content") or ""
+            ) # either string or none
+            if assistant_text:
+                assistant_content.append(PartType(text=assistant_text))
+            if messages[msg_i].get(
+                "tool_calls", []
+            ): # support assistant tool invoke convertion
+                assistant_content.extend(
+                    convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
+                )
+            msg_i += 1
+
+        if assistant_content:
+            contents.append(ContentType(role="model", parts=assistant_content))
+
+        ## APPEND TOOL CALL MESSAGES ##
+        if messages[msg_i]["role"] == "tool":
+            _part = convert_to_gemini_tool_call_result(messages[msg_i])
+            contents.append(ContentType(parts=[_part])) # type: ignore
+            msg_i += 1
+        if msg_i == init_msg_i: # prevent infinite loops
+            raise Exception(
+                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                    messages[msg_i]
+                )
+            )
+
+    return contents
+
+
 def _gemini_vision_convert_messages(messages: list):
     """
     Converts given messages for GPT-4 Vision to Gemini format.
@@ -574,11 +702,9 @@ def completion(
             print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
             tools = optional_params.pop("tools", None)
-            prompt, images = _gemini_vision_convert_messages(messages=messages)
-            content = [prompt] + images
+            content = _gemini_convert_messages_text(messages=messages)
             stream = optional_params.pop("stream", False)
             if stream == True:
-
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 logging_obj.pre_call(
                     input=prompt,
@@ -590,7 +716,7 @@ def completion(
                 )
 
                 model_response = llm_model.generate_content(
-                    contents=content,
+                    contents={"content": content},
                     generation_config=optional_params,
                     safety_settings=safety_settings,
                     stream=True,
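
To make the message translation above concrete, here is a hedged sketch of calling _gemini_convert_messages_text on the tool-call history used in the new tests; the comments describe the expected grouping per the function's docstring and body, not a captured transcript:

    # system/user turns merge into one role="user" content block, the assistant
    # tool call becomes a role="model" block with a function_call part, and the
    # tool result becomes a content block with a function_response part.
    from litellm.llms.vertex_ai import _gemini_convert_messages_text

    history = [
        {"role": "system", "content": "Your name is Litellm Bot, you are a helpful assistant"},
        {"role": "user", "content": "Hello, what is your name and can you tell me the weather?"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location":"San Francisco, CA"}',
                    },
                }
            ],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "get_weather",
            "content": "27 degrees celsius and clear in San Francisco, CA",
        },
    ]

    for content_block in _gemini_convert_messages_text(messages=history):
        print(content_block)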
@@ -1,15 +1,21 @@
 ============================= test session starts ==============================
-platform darwin -- Python 3.11.9, pytest-7.3.1, pluggy-1.3.0
-rootdir: /Users/krrishdholakia/Documents/litellm/litellm/tests
-plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1
+platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0
+rootdir: /Users/krrishdholakia/Documents/litellm
+configfile: pyproject.toml
+plugins: asyncio-0.23.6, mock-3.14.0, anyio-4.2.0
 asyncio: mode=Mode.STRICT
 collected 1 item
 
-test_router_timeout.py . [100%]
+test_amazing_vertex_completion.py Success: model=gemini-1.5-pro-preview-0514 in model_cost_map
+prompt_tokens=88; completion_tokens=31
+Returned custom cost for model=gemini-1.5-pro-preview-0514 - prompt_tokens_cost_usd_dollar: 5.5e-05, completion_tokens_cost_usd_dollar: 5.8125e-05
+final cost: 0.000113125; prompt_tokens_cost_usd_dollar: 5.5e-05; completion_tokens_cost_usd_dollar: 5.8125e-05
+success callbacks: []
+. [100%]
 
 =============================== warnings summary ===============================
-../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings
-  /opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
+../proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
   warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
 
 ../proxy/_types.py:255
@@ -68,9 +74,5 @@ test_router_timeout.py . [100%]
   /Users/krrishdholakia/Documents/litellm/litellm/utils.py:60: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
     with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f:
 
-test_router_timeout.py::test_router_timeouts_bedrock
-  /opt/homebrew/lib/python3.11/site-packages/httpx/_content.py:204: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
-    warnings.warn(message, DeprecationWarning)
-
 -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-======================== 1 passed, 40 warnings in 0.99s ========================
+======================== 1 passed, 39 warnings in 2.92s ========================
|
@ -16,6 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from litellm.llms.vertex_ai import _gemini_convert_messages_text
|
||||||
|
|
||||||
litellm.num_retries = 3
|
litellm.num_retries = 3
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
@ -98,7 +99,7 @@ def load_vertex_ai_credentials():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def get_response():
|
async def test_get_response():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
|
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
|
||||||
try:
|
try:
|
||||||
|
@ -589,35 +590,73 @@ def test_gemini_pro_vision_base64():
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
@pytest.mark.parametrize("sync_mode", [True])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gemini_pro_function_calling(sync_mode):
|
async def test_gemini_pro_function_calling(sync_mode):
|
||||||
try:
|
try:
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
data = {
|
litellm.set_verbose = True
|
||||||
"model": "vertex_ai/gemini-pro",
|
|
||||||
"messages": [
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||||
|
},
|
||||||
|
# User asks for their name and weather in San Francisco
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "Call the submit_cities function with San Francisco and New York",
|
"content": "Hello, what is your name and can you tell me the weather?",
|
||||||
|
},
|
||||||
|
# Assistant replies with a tool call
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_123",
|
||||||
|
"type": "function",
|
||||||
|
"index": 0,
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"location":"San Francisco, CA"}',
|
||||||
|
},
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"tools": [
|
},
|
||||||
|
# The result of the tool call is added to the history
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_123",
|
||||||
|
"name": "get_weather",
|
||||||
|
"content": "27 degrees celsius and clear in San Francisco, CA",
|
||||||
|
},
|
||||||
|
# Now the assistant can reply with the result of the tool call.
|
||||||
|
]
|
||||||
|
|
||||||
|
tools = [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "submit_cities",
|
"name": "get_weather",
|
||||||
"description": "Submits a list of cities",
|
"description": "Get the current weather in a given location",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"cities": {"type": "array", "items": {"type": "string"}}
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"required": ["cities"],
|
"required": ["location"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
],
|
]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": "vertex_ai/gemini-1.5-pro-preview-0514",
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
}
|
}
|
||||||
if sync_mode:
|
if sync_mode:
|
||||||
response = litellm.completion(**data)
|
response = litellm.completion(**data)
|
||||||
|
@ -712,7 +751,7 @@ async def test_gemini_pro_async_function_calling():
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"description": "Get the current weather in a given location",
|
"description": "Get the current weather in a given location. Response with a prediction on temperature as well.",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -724,8 +763,9 @@ async def test_gemini_pro_async_function_calling():
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["celsius", "fahrenheit"],
|
"enum": ["celsius", "fahrenheit"],
|
||||||
},
|
},
|
||||||
|
"predicted_temperature": {"type": "integer"},
|
||||||
},
|
},
|
||||||
"required": ["location"],
|
"required": ["location", "predicted_temperature"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -733,7 +773,7 @@ async def test_gemini_pro_async_function_calling():
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "What's the weather like in Boston today in fahrenheit?",
|
"content": "What's the weather like in Boston today in fahrenheit, give me a prediction?",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
completion = await litellm.acompletion(
|
completion = await litellm.acompletion(
|
||||||
|
@ -742,8 +782,10 @@ async def test_gemini_pro_async_function_calling():
|
||||||
print(f"completion: {completion}")
|
print(f"completion: {completion}")
|
||||||
assert completion.choices[0].message.content is None
|
assert completion.choices[0].message.content is None
|
||||||
assert len(completion.choices[0].message.tool_calls) == 1
|
assert len(completion.choices[0].message.tool_calls) == 1
|
||||||
except litellm.APIError as e:
|
|
||||||
pass
|
raise Exception
|
||||||
|
# except litellm.APIError as e:
|
||||||
|
# pass
|
||||||
except litellm.RateLimitError as e:
|
except litellm.RateLimitError as e:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -893,3 +935,46 @@ async def test_vertexai_aembedding():
|
||||||
# traceback.print_exc()
|
# traceback.print_exc()
|
||||||
# raise e
|
# raise e
|
||||||
# test_gemini_pro_vision_async()
|
# test_gemini_pro_vision_async()
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_factory():
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||||
|
},
|
||||||
|
# User asks for their name and weather in San Francisco
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, what is your name and can you tell me the weather?",
|
||||||
|
},
|
||||||
|
# Assistant replies with a tool call
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_123",
|
||||||
|
"type": "function",
|
||||||
|
"index": 0,
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"location":"San Francisco, CA"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
# The result of the tool call is added to the history
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_123",
|
||||||
|
"name": "get_weather",
|
||||||
|
"content": "27 degrees celsius and clear in San Francisco, CA",
|
||||||
|
},
|
||||||
|
# Now the assistant can reply with the result of the tool call.
|
||||||
|
]
|
||||||
|
|
||||||
|
translated_messages = _gemini_convert_messages_text(messages=messages)
|
||||||
|
|
||||||
|
print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
|
||||||
|
raise Exception
|
||||||
|
|
litellm/types/llms/vertex_ai.py (new file, 53 lines)

@@ -0,0 +1,53 @@
+from typing import TypedDict, Any, Union, Optional, List, Literal
+import json
+from typing_extensions import (
+    Self,
+    Protocol,
+    TypeGuard,
+    override,
+    get_origin,
+    runtime_checkable,
+    Required,
+)
+
+
+class Field(TypedDict):
+    key: str
+    value: dict[str, Any]
+
+
+class FunctionCallArgs(TypedDict):
+    fields: Field
+
+
+class FunctionResponse(TypedDict):
+    name: str
+    response: FunctionCallArgs
+
+
+class FunctionCall(TypedDict):
+    name: str
+    args: FunctionCallArgs
+
+
+class FileDataType(TypedDict):
+    mime_type: str
+    file_uri: str  # the cloud storage uri of storing this file
+
+
+class BlobType(TypedDict):
+    mime_type: Required[str]
+    data: Required[bytes]
+
+
+class PartType(TypedDict, total=False):
+    text: str
+    inline_data: BlobType
+    file_data: FileDataType
+    function_call: FunctionCall
+    function_response: FunctionResponse
+
+
+class ContentType(TypedDict, total=False):
+    role: Literal["user", "model"]
+    parts: Required[List[PartType]]
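
A minimal sketch of how these TypedDicts compose into a Gemini request part, assuming the module is importable as litellm.types.llms.vertex_ai; the weather strings are illustrative:

    # TypedDicts are plain dicts at runtime, so this just builds nested dictionaries.
    from litellm.types.llms.vertex_ai import (
        ContentType,
        Field,
        FunctionCallArgs,
        FunctionResponse,
        PartType,
    )

    field = Field(key="content", value={"string_value": "72 and sunny"})
    function_response = FunctionResponse(
        name="get_weather", response=FunctionCallArgs(fields=field)
    )
    part = PartType(function_response=function_response)
    content = ContentType(parts=[part])
    # -> {'parts': [{'function_response': {'name': 'get_weather',
    #     'response': {'fields': {'key': 'content', 'value': {'string_value': '72 and sunny'}}}}}]}
    print(content)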