fix(vertex_ai.py): support function calling for gemini

This commit is contained in:
Krrish Dholakia 2023-12-28 19:06:49 +05:30
parent a1484171b5
commit 86403cd14e
3 changed files with 167 additions and 95 deletions

View file

@ -5,7 +5,7 @@ import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
import litellm, uuid
import httpx
@ -264,14 +264,13 @@ def completion(
request_str = ""
response_obj = None
if model in litellm.vertex_language_models:
if (
model in litellm.vertex_language_models
or model in litellm.vertex_vision_models
):
llm_model = GenerativeModel(model)
mode = ""
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_vision_models:
llm_model = GenerativeModel(model)
request_str += f"llm_model = GenerativeModel({model})\n"
mode = "vision"
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_chat_models:
llm_model = ChatModel.from_pretrained(model)
mode = "chat"
@ -318,48 +317,10 @@ def completion(
**optional_params,
)
if mode == "":
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=stream,
)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
if "stream" in optional_params and optional_params["stream"] == True:
@ -379,6 +340,7 @@ def completion(
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=True,
tools=tools,
)
optional_params["stream"] = True
return model_response
@ -399,9 +361,35 @@ def completion(
contents=content,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
tools=tools,
)
completion_response = response.text
if tools is not None and hasattr(
response.candidates[0].content.parts[0], "function_call"
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v
args_str = json.dumps(args_dict)
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
completion_response = message
else:
completion_response = response.text
response_obj = response._raw_response
optional_params["tools"] = tools
elif mode == "chat":
chat = llm_model.start_chat()
request_str += f"chat = llm_model.start_chat()\n"
@ -479,7 +467,9 @@ def completion(
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
if isinstance(completion_response, litellm.Message):
model_response["choices"][0]["message"] = completion_response
elif len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
@ -533,26 +523,10 @@ async def async_completion(
try:
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params)
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
@ -570,10 +544,37 @@ async def async_completion(
## LLM Call
response = await llm_model._generate_content_async(
contents=content, generation_config=GenerationConfig(**optional_params)
contents=content,
generation_config=GenerationConfig(**optional_params),
tools=tools,
)
completion_response = response.text
if tools is not None and hasattr(
response.candidates[0].content.parts[0], "function_call"
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v
args_str = json.dumps(args_dict)
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
completion_response = message
else:
completion_response = response.text
response_obj = response._raw_response
optional_params["tools"] = tools
elif mode == "chat":
# chat-bison etc.
chat = llm_model.start_chat()
@ -609,7 +610,9 @@ async def async_completion(
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
if isinstance(completion_response, litellm.Message):
model_response["choices"][0]["message"] = completion_response
elif len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
@ -661,33 +664,14 @@ async def async_streaming(
"""
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
if mode == "vision":
stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params), stream=stream
)
optional_params["stream"] = True
elif mode == "vision":
stream = optional_params.pop("stream")
tools = optional_params.pop("tools", None)
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(
input=prompt,
@ -698,12 +682,13 @@ async def async_streaming(
},
)
response = llm_model._generate_content_streaming_async(
response = await llm_model._generate_content_streaming_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
stream=True,
tools=tools,
)
optional_params["stream"] = True
optional_params["tools"] = tools
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop(

View file

@ -98,7 +98,8 @@ def test_vertex_ai():
litellm.vertex_project = "reliablekeys"
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
# test_models += litellm.vertex_language_models # always test gemini-pro
test_models = litellm.vertex_language_models # always test gemini-pro
for model in test_models:
try:
if model in [
@ -303,6 +304,69 @@ def test_gemini_pro_vision():
# test_gemini_pro_vision()
def gemini_pro_function_calling():
load_vertex_ai_credentials()
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
completion = litellm.completion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
print(f"completion: {completion}")
# gemini_pro_function_calling()
async def gemini_pro_async_function_calling():
load_vertex_ai_credentials()
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
completion = await litellm.acompletion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
print(f"completion: {completion}")
asyncio.run(gemini_pro_async_function_calling())
# Extra gemini Vision tests for completion + stream, async, async + stream
# if we run into issues with gemini, we will also add these to our ci/cd pipeline
# def test_gemini_pro_vision_stream():

View file

@ -2939,6 +2939,7 @@ def get_optional_params(
custom_llm_provider != "openai"
and custom_llm_provider != "text-completion-openai"
and custom_llm_provider != "azure"
and custom_llm_provider != "vertex_ai"
):
if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
# ollama actually supports json output
@ -3238,7 +3239,14 @@ def get_optional_params(
optional_params["max_output_tokens"] = max_tokens
elif custom_llm_provider == "vertex_ai":
## check if unsupported param passed in
supported_params = ["temperature", "top_p", "max_tokens", "stream"]
supported_params = [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
]
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
@ -3249,6 +3257,21 @@ def get_optional_params(
optional_params["stream"] = stream
if max_tokens is not None:
optional_params["max_output_tokens"] = max_tokens
if tools is not None and isinstance(tools, list):
from vertexai.preview import generative_models
gtools = []
for tool in tools:
gtool = generative_models.FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declaration = generative_models.Tool(
function_declarations=[gtool]
)
gtools.append(gtool_func_declaration)
optional_params["tools"] = gtools
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]