mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(vertex_ai.py): support tool call list response async completion
This commit is contained in:
parent
7f6e933372
commit
04ae285001
2 changed files with 28 additions and 11 deletions
|
@ -867,6 +867,8 @@ async def async_completion(
|
||||||
Add support for acompletion calls for gemini-pro
|
Add support for acompletion calls for gemini-pro
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
import proto # type: ignore
|
||||||
|
|
||||||
if mode == "vision":
|
if mode == "vision":
|
||||||
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
|
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
|
||||||
print_verbose(f"\nProcessing input messages = {messages}")
|
print_verbose(f"\nProcessing input messages = {messages}")
|
||||||
|
@ -901,9 +903,21 @@ async def async_completion(
|
||||||
):
|
):
|
||||||
function_call = response.candidates[0].content.parts[0].function_call
|
function_call = response.candidates[0].content.parts[0].function_call
|
||||||
args_dict = {}
|
args_dict = {}
|
||||||
for k, v in function_call.args.items():
|
|
||||||
args_dict[k] = v
|
# Check if it's a RepeatedComposite instance
|
||||||
args_str = json.dumps(args_dict)
|
for key, val in function_call.args.items():
|
||||||
|
if isinstance(
|
||||||
|
val, proto.marshal.collections.repeated.RepeatedComposite
|
||||||
|
):
|
||||||
|
# If so, convert to list
|
||||||
|
args_dict[key] = [v for v in val]
|
||||||
|
else:
|
||||||
|
args_dict[key] = val
|
||||||
|
|
||||||
|
try:
|
||||||
|
args_str = json.dumps(args_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise VertexAIError(status_code=422, message=str(e))
|
||||||
message = litellm.Message(
|
message = litellm.Message(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
|
|
|
@ -590,19 +590,20 @@ def test_gemini_pro_vision_base64():
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
def test_gemini_pro_function_calling():
|
async def test_gemini_pro_function_calling(sync_mode):
|
||||||
try:
|
try:
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
response = litellm.completion(
|
data = {
|
||||||
model="vertex_ai/gemini-pro",
|
"model": "vertex_ai/gemini-pro",
|
||||||
messages=[
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "Call the submit_cities function with San Francisco and New York",
|
"content": "Call the submit_cities function with San Francisco and New York",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
tools=[
|
"tools": [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
|
@ -618,11 +619,13 @@ def test_gemini_pro_function_calling():
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
}
|
||||||
|
if sync_mode:
|
||||||
|
response = litellm.completion(**data)
|
||||||
|
else:
|
||||||
|
response = await litellm.acompletion(**data)
|
||||||
|
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
except litellm.APIError as e:
|
|
||||||
pass
|
|
||||||
except litellm.RateLimitError as e:
|
except litellm.RateLimitError as e:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue