fix(vertex_ai.py): support function calling for gemini

commit 86403cd14e
parent a1484171b5
Author: Krrish Dholakia
Date:   2023-12-28 19:06:49 +05:30

3 changed files with 167 additions and 95 deletions
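For context, what this change enables: an OpenAI-style function-calling request against gemini-pro on Vertex AI now goes through litellm's regular completion API. A minimal usage sketch (it mirrors the test added further down; it assumes Vertex AI credentials, project, and location are already configured, and the weather tool is purely illustrative):

    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]

    # tools / tool_choice are now accepted for vertex_ai models (see get_optional_params below)
    response = litellm.completion(
        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
    )
    print(response.choices[0].message)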


@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-import litellm
+import litellm, uuid
 import httpx
@@ -264,14 +264,13 @@ def completion(
         request_str = ""
         response_obj = None
-        if model in litellm.vertex_language_models:
+        if (
+            model in litellm.vertex_language_models
+            or model in litellm.vertex_vision_models
+        ):
             llm_model = GenerativeModel(model)
-            mode = ""
-            request_str += f"llm_model = GenerativeModel({model})\n"
-        elif model in litellm.vertex_vision_models:
-            llm_model = GenerativeModel(model)
-            request_str += f"llm_model = GenerativeModel({model})\n"
             mode = "vision"
+            request_str += f"llm_model = GenerativeModel({model})\n"
         elif model in litellm.vertex_chat_models:
             llm_model = ChatModel.from_pretrained(model)
             mode = "chat"
@@ -318,48 +317,10 @@ def completion(
                 **optional_params,
             )
-        if mode == "":
-            if "stream" in optional_params and optional_params["stream"] == True:
-                stream = optional_params.pop("stream")
-                request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
-                ## LOGGING
-                logging_obj.pre_call(
-                    input=prompt,
-                    api_key=None,
-                    additional_args={
-                        "complete_input_dict": optional_params,
-                        "request_str": request_str,
-                    },
-                )
-                model_response = llm_model.generate_content(
-                    prompt,
-                    generation_config=GenerationConfig(**optional_params),
-                    safety_settings=safety_settings,
-                    stream=stream,
-                )
-                optional_params["stream"] = True
-                return model_response
-            request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
-            ## LOGGING
-            logging_obj.pre_call(
-                input=prompt,
-                api_key=None,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            response_obj = llm_model.generate_content(
-                prompt,
-                generation_config=GenerationConfig(**optional_params),
-                safety_settings=safety_settings,
-            )
-            completion_response = response_obj.text
-            response_obj = response_obj._raw_response
-        elif mode == "vision":
+        if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
+            tools = optional_params.pop("tools", None)
             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
             if "stream" in optional_params and optional_params["stream"] == True:
@@ -379,6 +340,7 @@ def completion(
                     generation_config=GenerationConfig(**optional_params),
                     safety_settings=safety_settings,
                     stream=True,
+                    tools=tools,
                 )
                 optional_params["stream"] = True
                 return model_response
@@ -399,9 +361,35 @@ def completion(
                 contents=content,
                 generation_config=GenerationConfig(**optional_params),
                 safety_settings=safety_settings,
+                tools=tools,
             )
-            completion_response = response.text
+            if tools is not None and hasattr(
+                response.candidates[0].content.parts[0], "function_call"
+            ):
+                function_call = response.candidates[0].content.parts[0].function_call
+                args_dict = {}
+                for k, v in function_call.args.items():
+                    args_dict[k] = v
+                args_str = json.dumps(args_dict)
+                message = litellm.Message(
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": f"call_{str(uuid.uuid4())}",
+                            "function": {
+                                "arguments": args_str,
+                                "name": function_call.name,
+                            },
+                            "type": "function",
+                        }
+                    ],
+                )
+                completion_response = message
+            else:
+                completion_response = response.text
             response_obj = response._raw_response
+            optional_params["tools"] = tools
         elif mode == "chat":
             chat = llm_model.start_chat()
             request_str += f"chat = llm_model.start_chat()\n"
@@ -479,7 +467,9 @@ def completion(
         )
         ## RESPONSE OBJECT
-        if len(str(completion_response)) > 0:
+        if isinstance(completion_response, litellm.Message):
+            model_response["choices"][0]["message"] = completion_response
+        elif len(str(completion_response)) > 0:
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
@@ -533,26 +523,10 @@ async def async_completion(
     try:
         from vertexai.preview.generative_models import GenerationConfig
-        if mode == "":
-            # gemini-pro
-            chat = llm_model.start_chat()
-            ## LOGGING
-            logging_obj.pre_call(
-                input=prompt,
-                api_key=None,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            response_obj = await chat.send_message_async(
-                prompt, generation_config=GenerationConfig(**optional_params)
-            )
-            completion_response = response_obj.text
-            response_obj = response_obj._raw_response
-        elif mode == "vision":
+        if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
+            tools = optional_params.pop("tools", None)
             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
@@ -570,10 +544,37 @@ async def async_completion(
             ## LLM Call
             response = await llm_model._generate_content_async(
-                contents=content, generation_config=GenerationConfig(**optional_params)
+                contents=content,
+                generation_config=GenerationConfig(**optional_params),
+                tools=tools,
             )
-            completion_response = response.text
+            if tools is not None and hasattr(
+                response.candidates[0].content.parts[0], "function_call"
+            ):
+                function_call = response.candidates[0].content.parts[0].function_call
+                args_dict = {}
+                for k, v in function_call.args.items():
+                    args_dict[k] = v
+                args_str = json.dumps(args_dict)
+                message = litellm.Message(
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": f"call_{str(uuid.uuid4())}",
+                            "function": {
+                                "arguments": args_str,
+                                "name": function_call.name,
+                            },
+                            "type": "function",
+                        }
+                    ],
+                )
+                completion_response = message
+            else:
+                completion_response = response.text
             response_obj = response._raw_response
+            optional_params["tools"] = tools
         elif mode == "chat":
             # chat-bison etc.
             chat = llm_model.start_chat()
@@ -609,7 +610,9 @@ async def async_completion(
         )
         ## RESPONSE OBJECT
-        if len(str(completion_response)) > 0:
+        if isinstance(completion_response, litellm.Message):
+            model_response["choices"][0]["message"] = completion_response
+        elif len(str(completion_response)) > 0:
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
@@ -661,33 +664,14 @@ async def async_streaming(
     """
     from vertexai.preview.generative_models import GenerationConfig
-    if mode == "":
-        # gemini-pro
-        chat = llm_model.start_chat()
+    if mode == "vision":
         stream = optional_params.pop("stream")
-        request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
+        tools = optional_params.pop("tools", None)
-        ## LOGGING
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-        response = await chat.send_message_async(
-            prompt, generation_config=GenerationConfig(**optional_params), stream=stream
-        )
-        optional_params["stream"] = True
-    elif mode == "vision":
-        stream = optional_params.pop("stream")
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
         prompt, images = _gemini_vision_convert_messages(messages=messages)
         content = [prompt] + images
-        stream = optional_params.pop("stream")
         request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
         logging_obj.pre_call(
             input=prompt,
@@ -698,12 +682,13 @@ async def async_streaming(
             },
         )
-        response = llm_model._generate_content_streaming_async(
+        response = await llm_model._generate_content_streaming_async(
             contents=content,
             generation_config=GenerationConfig(**optional_params),
-            stream=True,
+            tools=tools,
         )
         optional_params["stream"] = True
+        optional_params["tools"] = tools
     elif mode == "chat":
         chat = llm_model.start_chat()
         optional_params.pop(
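For reference, a rough sketch of how a caller might consume the tool call that the gemini/vision path now returns. This assumes the OpenAI-style tool_calls entries constructed above are surfaced as-is on the returned message (depending on the litellm version they may instead be typed objects rather than plain dicts), and it reuses the illustrative messages/tools from the sketch near the top:

    import json

    import litellm

    response = litellm.completion(
        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
    )
    message = response.choices[0].message
    for call in getattr(message, "tool_calls", None) or []:
        # arguments come back as a JSON string (see args_str in the diff above), so decode them
        args = json.loads(call["function"]["arguments"])
        if call["function"]["name"] == "get_current_weather":
            print("model requested weather for:", args.get("location"))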


@@ -98,7 +98,8 @@ def test_vertex_ai():
     litellm.vertex_project = "reliablekeys"
     test_models = random.sample(test_models, 1)
-    test_models += litellm.vertex_language_models  # always test gemini-pro
+    # test_models += litellm.vertex_language_models # always test gemini-pro
+    test_models = litellm.vertex_language_models  # always test gemini-pro
     for model in test_models:
         try:
             if model in [
@@ -303,6 +304,69 @@ def test_gemini_pro_vision():
 # test_gemini_pro_vision()
 
 
+def gemini_pro_function_calling():
+    load_vertex_ai_credentials()
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    completion = litellm.completion(
+        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
+    )
+    print(f"completion: {completion}")
+
+
+# gemini_pro_function_calling()
+
+
+async def gemini_pro_async_function_calling():
+    load_vertex_ai_credentials()
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    completion = await litellm.acompletion(
+        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
+    )
+    print(f"completion: {completion}")
+
+
+asyncio.run(gemini_pro_async_function_calling())
+
+
 # Extra gemini Vision tests for completion + stream, async, async + stream
 # if we run into issues with gemini, we will also add these to our ci/cd pipeline
 # def test_gemini_pro_vision_stream():


@@ -2939,6 +2939,7 @@ def get_optional_params(
         custom_llm_provider != "openai"
         and custom_llm_provider != "text-completion-openai"
         and custom_llm_provider != "azure"
+        and custom_llm_provider != "vertex_ai"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
@@ -3238,7 +3239,14 @@ def get_optional_params(
             optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "vertex_ai":
         ## check if unsupported param passed in
-        supported_params = ["temperature", "top_p", "max_tokens", "stream"]
+        supported_params = [
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
+            "tools",
+            "tool_choice",
+        ]
         _check_valid_arg(supported_params=supported_params)
 
         if temperature is not None:
@@ -3249,6 +3257,21 @@ def get_optional_params(
             optional_params["stream"] = stream
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
+        if tools is not None and isinstance(tools, list):
+            from vertexai.preview import generative_models
+
+            gtools = []
+            for tool in tools:
+                gtool = generative_models.FunctionDeclaration(
+                    name=tool["function"]["name"],
+                    description=tool["function"].get("description", ""),
+                    parameters=tool["function"].get("parameters", {}),
+                )
+                gtool_func_declaration = generative_models.Tool(
+                    function_declarations=[gtool]
+                )
+                gtools.append(gtool_func_declaration)
+            optional_params["tools"] = gtools
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]