Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(vertex_ai.py): support function calling for gemini
Commit 86403cd14e (parent a1484171b5)
3 changed files with 167 additions and 95 deletions
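The change lets OpenAI-style `tools` / `tool_choice` arguments flow through `litellm.completion` to Vertex AI Gemini models, and maps a returned `function_call` back into an OpenAI-compatible `tool_calls` entry. A minimal usage sketch, mirroring the test added in this commit (project/location values are placeholders; the response-field access is illustrative and may vary slightly by litellm version):

import litellm

litellm.vertex_project = "my-gcp-project"   # placeholder, not part of the diff
litellm.vertex_location = "us-central1"     # placeholder, not part of the diff

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="gemini-pro",
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    tools=tools,
    tool_choice="auto",
)
# tool_calls is populated when Gemini returns a function call; otherwise content holds text
print(response.choices[0].message.tool_calls)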
vertex_ai.py

@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-import litellm
+import litellm, uuid
 import httpx
 
 
@@ -264,14 +264,13 @@ def completion(
 
         request_str = ""
         response_obj = None
-        if model in litellm.vertex_language_models:
+        if (
+            model in litellm.vertex_language_models
+            or model in litellm.vertex_vision_models
+        ):
             llm_model = GenerativeModel(model)
-            mode = ""
-            request_str += f"llm_model = GenerativeModel({model})\n"
-        elif model in litellm.vertex_vision_models:
-            llm_model = GenerativeModel(model)
-            request_str += f"llm_model = GenerativeModel({model})\n"
             mode = "vision"
+            request_str += f"llm_model = GenerativeModel({model})\n"
         elif model in litellm.vertex_chat_models:
             llm_model = ChatModel.from_pretrained(model)
             mode = "chat"
@@ -318,48 +317,10 @@ def completion(
                 **optional_params,
             )
 
-        if mode == "":
-            if "stream" in optional_params and optional_params["stream"] == True:
-                stream = optional_params.pop("stream")
-                request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
-                ## LOGGING
-                logging_obj.pre_call(
-                    input=prompt,
-                    api_key=None,
-                    additional_args={
-                        "complete_input_dict": optional_params,
-                        "request_str": request_str,
-                    },
-                )
-                model_response = llm_model.generate_content(
-                    prompt,
-                    generation_config=GenerationConfig(**optional_params),
-                    safety_settings=safety_settings,
-                    stream=stream,
-                )
-                optional_params["stream"] = True
-                return model_response
-            request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
-            ## LOGGING
-            logging_obj.pre_call(
-                input=prompt,
-                api_key=None,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            response_obj = llm_model.generate_content(
-                prompt,
-                generation_config=GenerationConfig(**optional_params),
-                safety_settings=safety_settings,
-            )
-            completion_response = response_obj.text
-            response_obj = response_obj._raw_response
-        elif mode == "vision":
+        if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
+            tools = optional_params.pop("tools", None)
             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
             if "stream" in optional_params and optional_params["stream"] == True:
@@ -379,6 +340,7 @@ def completion(
                     generation_config=GenerationConfig(**optional_params),
                     safety_settings=safety_settings,
                     stream=True,
+                    tools=tools,
                 )
                 optional_params["stream"] = True
                 return model_response
@@ -399,9 +361,35 @@ def completion(
                 contents=content,
                 generation_config=GenerationConfig(**optional_params),
                 safety_settings=safety_settings,
+                tools=tools,
             )
-            completion_response = response.text
+
+            if tools is not None and hasattr(
+                response.candidates[0].content.parts[0], "function_call"
+            ):
+                function_call = response.candidates[0].content.parts[0].function_call
+                args_dict = {}
+                for k, v in function_call.args.items():
+                    args_dict[k] = v
+                args_str = json.dumps(args_dict)
+                message = litellm.Message(
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": f"call_{str(uuid.uuid4())}",
+                            "function": {
+                                "arguments": args_str,
+                                "name": function_call.name,
+                            },
+                            "type": "function",
+                        }
+                    ],
+                )
+                completion_response = message
+            else:
+                completion_response = response.text
             response_obj = response._raw_response
+            optional_params["tools"] = tools
         elif mode == "chat":
             chat = llm_model.start_chat()
             request_str += f"chat = llm_model.start_chat()\n"
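When Gemini returns a function_call, the litellm.Message assembled above means the final response carries an OpenAI-style tool call instead of text. Roughly, the resulting message looks like this (illustrative values; extra fields and exact serialization may differ by litellm version):

# choices[0].message for a tool-call response, roughly:
# {
#     "role": "assistant",
#     "content": None,
#     "tool_calls": [
#         {
#             "id": "call_<uuid4>",
#             "type": "function",
#             "function": {
#                 "name": "get_current_weather",
#                 "arguments": "{\"location\": \"Boston, MA\", \"unit\": \"celsius\"}",
#             },
#         }
#     ],
# }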
@@ -479,7 +467,9 @@ def completion(
             )
 
         ## RESPONSE OBJECT
-        if len(str(completion_response)) > 0:
+        if isinstance(completion_response, litellm.Message):
+            model_response["choices"][0]["message"] = completion_response
+        elif len(str(completion_response)) > 0:
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
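A hedged sketch of how a caller might consume the result: the arguments come back as a JSON string (built with json.dumps above), so json.loads recovers the dict. Dict-style access is used for the tool-call entries because the exact wrapping of tool_calls can vary across litellm versions:

import json

message = response.choices[0].message  # `response` from the litellm.completion call above
if getattr(message, "tool_calls", None):
    call = message.tool_calls[0]
    name = call["function"]["name"]                    # e.g. "get_current_weather"
    args = json.loads(call["function"]["arguments"])   # dict of Gemini's function_call args
    print(name, args)
else:
    print(message.content)  # plain-text path is unchanged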
@@ -533,26 +523,10 @@ async def async_completion(
     try:
         from vertexai.preview.generative_models import GenerationConfig
 
-        if mode == "":
-            # gemini-pro
-            chat = llm_model.start_chat()
-            ## LOGGING
-            logging_obj.pre_call(
-                input=prompt,
-                api_key=None,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            response_obj = await chat.send_message_async(
-                prompt, generation_config=GenerationConfig(**optional_params)
-            )
-            completion_response = response_obj.text
-            response_obj = response_obj._raw_response
-        elif mode == "vision":
+        if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
+            tools = optional_params.pop("tools", None)
 
             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
@@ -570,10 +544,37 @@ async def async_completion(
 
             ## LLM Call
             response = await llm_model._generate_content_async(
-                contents=content, generation_config=GenerationConfig(**optional_params)
+                contents=content,
+                generation_config=GenerationConfig(**optional_params),
+                tools=tools,
             )
-            completion_response = response.text
+
+            if tools is not None and hasattr(
+                response.candidates[0].content.parts[0], "function_call"
+            ):
+                function_call = response.candidates[0].content.parts[0].function_call
+                args_dict = {}
+                for k, v in function_call.args.items():
+                    args_dict[k] = v
+                args_str = json.dumps(args_dict)
+                message = litellm.Message(
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": f"call_{str(uuid.uuid4())}",
+                            "function": {
+                                "arguments": args_str,
+                                "name": function_call.name,
+                            },
+                            "type": "function",
+                        }
+                    ],
+                )
+                completion_response = message
+            else:
+                completion_response = response.text
             response_obj = response._raw_response
+            optional_params["tools"] = tools
         elif mode == "chat":
             # chat-bison etc.
             chat = llm_model.start_chat()
@@ -609,7 +610,9 @@ async def async_completion(
             )
 
         ## RESPONSE OBJECT
-        if len(str(completion_response)) > 0:
+        if isinstance(completion_response, litellm.Message):
+            model_response["choices"][0]["message"] = completion_response
+        elif len(str(completion_response)) > 0:
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
@@ -661,33 +664,14 @@ async def async_streaming(
     """
     from vertexai.preview.generative_models import GenerationConfig
 
-    if mode == "":
-        # gemini-pro
-        chat = llm_model.start_chat()
+    if mode == "vision":
         stream = optional_params.pop("stream")
-        request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
-        ## LOGGING
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-        response = await chat.send_message_async(
-            prompt, generation_config=GenerationConfig(**optional_params), stream=stream
-        )
-        optional_params["stream"] = True
-    elif mode == "vision":
-        stream = optional_params.pop("stream")
-
+        tools = optional_params.pop("tools", None)
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
 
         prompt, images = _gemini_vision_convert_messages(messages=messages)
         content = [prompt] + images
-        stream = optional_params.pop("stream")
         request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
         logging_obj.pre_call(
             input=prompt,
@@ -698,12 +682,13 @@ async def async_streaming(
             },
         )
 
-        response = llm_model._generate_content_streaming_async(
+        response = await llm_model._generate_content_streaming_async(
             contents=content,
             generation_config=GenerationConfig(**optional_params),
-            stream=True,
+            tools=tools,
         )
         optional_params["stream"] = True
+        optional_params["tools"] = tools
     elif mode == "chat":
         chat = llm_model.start_chat()
         optional_params.pop(
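The remaining hunks touch the Vertex AI tests (test_vertex_ai / gemini function-calling helpers) and get_optional_params. Before those, a hedged sketch of exercising the async streaming path above through litellm's public API; chunk handling follows the usual OpenAI-style delta shape, and credentials/project setup is assumed to be configured already:

import asyncio
import litellm

async def stream_gemini():
    # assumes Vertex AI credentials and vertex_project are already configured for litellm
    response = await litellm.acompletion(
        model="gemini-pro",
        messages=[{"role": "user", "content": "Write a haiku about the sea."}],
        stream=True,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="")

asyncio.run(stream_gemini())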
@@ -98,7 +98,8 @@ def test_vertex_ai():
     litellm.vertex_project = "reliablekeys"
 
     test_models = random.sample(test_models, 1)
-    test_models += litellm.vertex_language_models  # always test gemini-pro
+    # test_models += litellm.vertex_language_models # always test gemini-pro
+    test_models = litellm.vertex_language_models  # always test gemini-pro
     for model in test_models:
         try:
             if model in [
@@ -303,6 +304,69 @@ def test_gemini_pro_vision():
 # test_gemini_pro_vision()
 
 
+def gemini_pro_function_calling():
+    load_vertex_ai_credentials()
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    completion = litellm.completion(
+        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
+    )
+    print(f"completion: {completion}")
+
+
+# gemini_pro_function_calling()
+
+
+async def gemini_pro_async_function_calling():
+    load_vertex_ai_credentials()
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    completion = await litellm.acompletion(
+        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
+    )
+    print(f"completion: {completion}")
+
+
+asyncio.run(gemini_pro_async_function_calling())
+
 # Extra gemini Vision tests for completion + stream, async, async + stream
 # if we run into issues with gemini, we will also add these to our ci/cd pipeline
 # def test_gemini_pro_vision_stream():
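These helpers only print the response; if they were promoted to asserting tests, a check along these lines would pin the new behavior down. This is a suggestion rather than part of the commit: it reuses the tools/messages lists defined in gemini_pro_function_calling above and assumes dict-style tool_call entries:

def test_gemini_pro_function_calling_response_shape():
    # hypothetical follow-up test, not in this commit
    load_vertex_ai_credentials()
    completion = litellm.completion(
        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
    )
    message = completion.choices[0].message
    assert message.content is None
    assert message.tool_calls, "expected Gemini to return a function call"
    assert message.tool_calls[0]["function"]["name"] == "get_current_weather"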
@@ -2939,6 +2939,7 @@ def get_optional_params(
         custom_llm_provider != "openai"
         and custom_llm_provider != "text-completion-openai"
         and custom_llm_provider != "azure"
+        and custom_llm_provider != "vertex_ai"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
@@ -3238,7 +3239,14 @@ def get_optional_params(
             optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "vertex_ai":
         ## check if unsupported param passed in
-        supported_params = ["temperature", "top_p", "max_tokens", "stream"]
+        supported_params = [
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
+            "tools",
+            "tool_choice",
+        ]
         _check_valid_arg(supported_params=supported_params)
 
         if temperature is not None:
@@ -3249,6 +3257,21 @@ def get_optional_params(
             optional_params["stream"] = stream
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
+        if tools is not None and isinstance(tools, list):
+            from vertexai.preview import generative_models
+
+            gtools = []
+            for tool in tools:
+                gtool = generative_models.FunctionDeclaration(
+                    name=tool["function"]["name"],
+                    description=tool["function"].get("description", ""),
+                    parameters=tool["function"].get("parameters", {}),
+                )
+                gtool_func_declaration = generative_models.Tool(
+                    function_declarations=[gtool]
+                )
+                gtools.append(gtool_func_declaration)
+            optional_params["tools"] = gtools
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
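Note the design choice above: each OpenAI-style tool is wrapped in its own generative_models.Tool holding a single FunctionDeclaration, rather than batching all declarations into one Tool. A standalone sketch of the same mapping, assuming the vertexai preview SDK is installed and the tool dicts follow the OpenAI tools schema (the helper name is illustrative, not a litellm API):

from vertexai.preview import generative_models

def openai_tools_to_vertex(tools: list) -> list:
    """Sketch mirroring the conversion in get_optional_params (not the library API)."""
    gtools = []
    for tool in tools:
        declaration = generative_models.FunctionDeclaration(
            name=tool["function"]["name"],
            description=tool["function"].get("description", ""),
            parameters=tool["function"].get("parameters", {}),
        )
        # one Tool per declaration, matching the commit's behavior
        gtools.append(generative_models.Tool(function_declarations=[declaration]))
    return gtools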