fix(vertex_ai.py): support tool call list response async completion

Krrish Dholakia 2024-05-13 10:42:31 -07:00
parent 7f6e933372
commit 04ae285001
2 changed files with 28 additions and 11 deletions
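
Why the proto import and the isinstance check matter: Gemini returns function-call arguments as proto-plus values, and a list-valued argument arrives as a RepeatedComposite container, which json.dumps cannot serialize. A minimal standalone sketch of the conversion this commit introduces (the helper name serialize_function_args is illustrative, not part of litellm):

    import json

    import proto  # type: ignore  # proto-plus, installed with the Vertex AI SDK


    def serialize_function_args(args) -> str:
        # Unwrap proto-plus repeated containers into plain lists so the
        # resulting dict is JSON-serializable; scalar values pass through.
        args_dict = {}
        for key, val in args.items():
            if isinstance(val, proto.marshal.collections.repeated.RepeatedComposite):
                args_dict[key] = [v for v in val]  # e.g. ["San Francisco", "New York"]
            else:
                args_dict[key] = val
        return json.dumps(args_dict)

Without the unwrapping step, json.dumps raises a TypeError ("Object of type RepeatedComposite is not JSON serializable"), which is the failure this commit fixes on the async completion path.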

vertex_ai.py

@@ -867,6 +867,8 @@ async def async_completion(
     Add support for acompletion calls for gemini-pro
     """
     try:
+        import proto  # type: ignore
+
         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
@@ -901,9 +903,21 @@ async def async_completion(
             ):
                 function_call = response.candidates[0].content.parts[0].function_call
                 args_dict = {}
-                for k, v in function_call.args.items():
-                    args_dict[k] = v
-                args_str = json.dumps(args_dict)
+
+                # Check if it's a RepeatedComposite instance
+                for key, val in function_call.args.items():
+                    if isinstance(
+                        val, proto.marshal.collections.repeated.RepeatedComposite
+                    ):
+                        # If so, convert to list
+                        args_dict[key] = [v for v in val]
+                    else:
+                        args_dict[key] = val
+
+                try:
+                    args_str = json.dumps(args_dict)
+                except Exception as e:
+                    raise VertexAIError(status_code=422, message=str(e))
+
                 message = litellm.Message(
                     content=None,
                     tool_calls=[

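Note the error-handling choice in the hunk above: json.dumps is now wrapped in a try/except, so any remaining non-serializable argument surfaces as a VertexAIError with status 422 rather than an uncaught TypeError, mapping the failure to an explicit unprocessable-request error for the caller.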

@@ -590,19 +590,20 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-def test_gemini_pro_function_calling():
+async def test_gemini_pro_function_calling(sync_mode):
     try:
         load_vertex_ai_credentials()
-        response = litellm.completion(
-            model="vertex_ai/gemini-pro",
-            messages=[
+        data = {
+            "model": "vertex_ai/gemini-pro",
+            "messages": [
                 {
                     "role": "user",
                     "content": "Call the submit_cities function with San Francisco and New York",
                 }
             ],
-            tools=[
+            "tools": [
                 {
                     "type": "function",
                     "function": {
@@ -618,11 +619,13 @@ def test_gemini_pro_function_calling():
                     },
                 }
             ],
-        )
+        }
+        if sync_mode:
+            response = litellm.completion(**data)
+        else:
+            response = await litellm.acompletion(**data)
         print(f"response: {response}")
+    except litellm.APIError as e:
+        pass
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
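
With the sync_mode parametrization, one test definition now exercises both entry points: litellm.completion(**data) on the sync path and await litellm.acompletion(**data) on the async path, so the list-argument fix is covered in both code paths. The new except litellm.APIError clause, like the existing RateLimitError handler, appears intended to keep transient provider-side failures from failing the suite.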