mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(vertex_ai.py): support function calling for gemini
parent a1484171b5
commit 86403cd14e
3 changed files with 167 additions and 95 deletions
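The diff below (litellm's Vertex AI handler) maps Gemini function_call responses onto OpenAI-style tool_calls. A minimal usage sketch of what this commit enables; the tool schema and model name are illustrative assumptions, not taken from the commit:

# Hedged example: function calling against Vertex AI Gemini through litellm.
# Assumes Vertex AI credentials/project are already configured for litellm.
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool for illustration
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = litellm.completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=tools,
)
print(response.choices[0].message)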
@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-import litellm
+import litellm, uuid
 import httpx


@@ -264,14 +264,13 @@ def completion(

         request_str = ""
         response_obj = None
-        if model in litellm.vertex_language_models:
+        if (
+            model in litellm.vertex_language_models
+            or model in litellm.vertex_vision_models
+        ):
             llm_model = GenerativeModel(model)
-            mode = ""
-            request_str += f"llm_model = GenerativeModel({model})\n"
-        elif model in litellm.vertex_vision_models:
-            llm_model = GenerativeModel(model)
-            request_str += f"llm_model = GenerativeModel({model})\n"
             mode = "vision"
+            request_str += f"llm_model = GenerativeModel({model})\n"
         elif model in litellm.vertex_chat_models:
             llm_model = ChatModel.from_pretrained(model)
             mode = "chat"
@@ -318,48 +317,10 @@ def completion(
                 **optional_params,
             )
-
-        if mode == "":
-            if "stream" in optional_params and optional_params["stream"] == True:
-                stream = optional_params.pop("stream")
-                request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
-                ## LOGGING
-                logging_obj.pre_call(
-                    input=prompt,
-                    api_key=None,
-                    additional_args={
-                        "complete_input_dict": optional_params,
-                        "request_str": request_str,
-                    },
-                )
-                model_response = llm_model.generate_content(
-                    prompt,
-                    generation_config=GenerationConfig(**optional_params),
-                    safety_settings=safety_settings,
-                    stream=stream,
-                )
-                optional_params["stream"] = True
-                return model_response
-            request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
-            ## LOGGING
-            logging_obj.pre_call(
-                input=prompt,
-                api_key=None,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            response_obj = llm_model.generate_content(
-                prompt,
-                generation_config=GenerationConfig(**optional_params),
-                safety_settings=safety_settings,
-            )
-            completion_response = response_obj.text
-            response_obj = response_obj._raw_response
-        elif mode == "vision":
+        if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")

+            tools = optional_params.pop("tools", None)
             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
             if "stream" in optional_params and optional_params["stream"] == True:
@@ -379,6 +340,7 @@ def completion(
                     generation_config=GenerationConfig(**optional_params),
                     safety_settings=safety_settings,
                     stream=True,
+                    tools=tools,
                 )
                 optional_params["stream"] = True
                 return model_response
@@ -399,9 +361,35 @@ def completion(
                 contents=content,
                 generation_config=GenerationConfig(**optional_params),
                 safety_settings=safety_settings,
+                tools=tools,
             )
-            completion_response = response.text
+
+            if tools is not None and hasattr(
+                response.candidates[0].content.parts[0], "function_call"
+            ):
+                function_call = response.candidates[0].content.parts[0].function_call
+                args_dict = {}
+                for k, v in function_call.args.items():
+                    args_dict[k] = v
+                args_str = json.dumps(args_dict)
+                message = litellm.Message(
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": f"call_{str(uuid.uuid4())}",
+                            "function": {
+                                "arguments": args_str,
+                                "name": function_call.name,
+                            },
+                            "type": "function",
+                        }
+                    ],
+                )
+                completion_response = message
+            else:
+                completion_response = response.text
             response_obj = response._raw_response
+            optional_params["tools"] = tools
         elif mode == "chat":
             chat = llm_model.start_chat()
             request_str += f"chat = llm_model.start_chat()\n"
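When Gemini returns a function_call, the branch above wraps it in a litellm.Message whose tool_calls entries follow the OpenAI shape (id, type, and function.name / function.arguments, with the arguments JSON-encoded). A sketch of how a caller might consume that, assuming the response from the sketch near the top; depending on the litellm version the entries may be plain dicts (as built here) or typed objects, so dict-style access is shown as an assumption:

# Hedged example: dispatching the tool call produced by the code above.
import json

message = response.choices[0].message
for call in message.tool_calls or []:
    fn_name = call["function"]["name"]
    fn_args = json.loads(call["function"]["arguments"])
    # look up fn_name in a local registry and execute it with fn_args
    print(fn_name, fn_args)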
@@ -479,7 +467,9 @@ def completion(
         )

         ## RESPONSE OBJECT
-        if len(str(completion_response)) > 0:
+        if isinstance(completion_response, litellm.Message):
+            model_response["choices"][0]["message"] = completion_response
+        elif len(str(completion_response)) > 0:
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
@@ -533,26 +523,10 @@ async def async_completion(
     try:
         from vertexai.preview.generative_models import GenerationConfig

-        if mode == "":
-            # gemini-pro
-            chat = llm_model.start_chat()
-            ## LOGGING
-            logging_obj.pre_call(
-                input=prompt,
-                api_key=None,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            response_obj = await chat.send_message_async(
-                prompt, generation_config=GenerationConfig(**optional_params)
-            )
-            completion_response = response_obj.text
-            response_obj = response_obj._raw_response
-        elif mode == "vision":
+        if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
+            tools = optional_params.pop("tools", None)

             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
@@ -570,10 +544,37 @@ async def async_completion(

             ## LLM Call
             response = await llm_model._generate_content_async(
-                contents=content, generation_config=GenerationConfig(**optional_params)
+                contents=content,
+                generation_config=GenerationConfig(**optional_params),
+                tools=tools,
             )
-            completion_response = response.text
+
+            if tools is not None and hasattr(
+                response.candidates[0].content.parts[0], "function_call"
+            ):
+                function_call = response.candidates[0].content.parts[0].function_call
+                args_dict = {}
+                for k, v in function_call.args.items():
+                    args_dict[k] = v
+                args_str = json.dumps(args_dict)
+                message = litellm.Message(
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": f"call_{str(uuid.uuid4())}",
+                            "function": {
+                                "arguments": args_str,
+                                "name": function_call.name,
+                            },
+                            "type": "function",
+                        }
+                    ],
+                )
+                completion_response = message
+            else:
+                completion_response = response.text
             response_obj = response._raw_response
+            optional_params["tools"] = tools
         elif mode == "chat":
             # chat-bison etc.
             chat = llm_model.start_chat()
@@ -609,7 +610,9 @@ async def async_completion(
         )

         ## RESPONSE OBJECT
-        if len(str(completion_response)) > 0:
+        if isinstance(completion_response, litellm.Message):
+            model_response["choices"][0]["message"] = completion_response
+        elif len(str(completion_response)) > 0:
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
@@ -661,33 +664,14 @@ async def async_streaming(
     """
     from vertexai.preview.generative_models import GenerationConfig

-    if mode == "":
-        # gemini-pro
-        chat = llm_model.start_chat()
-        stream = optional_params.pop("stream")
-        request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
-        ## LOGGING
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-        response = await chat.send_message_async(
-            prompt, generation_config=GenerationConfig(**optional_params), stream=stream
-        )
-        optional_params["stream"] = True
-    elif mode == "vision":
-        stream = optional_params.pop("stream")
-
+    if mode == "vision":
+        tools = optional_params.pop("tools", None)
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")

         prompt, images = _gemini_vision_convert_messages(messages=messages)
         content = [prompt] + images
         stream = optional_params.pop("stream")
         request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
         logging_obj.pre_call(
             input=prompt,
@@ -698,12 +682,13 @@ async def async_streaming(
             },
         )
-
-        response = llm_model._generate_content_streaming_async(
+        response = await llm_model._generate_content_streaming_async(
             contents=content,
             generation_config=GenerationConfig(**optional_params),
             stream=True,
+            tools=tools,
         )
         optional_params["stream"] = True
+        optional_params["tools"] = tools
     elif mode == "chat":
         chat = llm_model.start_chat()
         optional_params.pop(
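The async_streaming change passes the same tools through to _generate_content_streaming_async, so a streaming request can be issued the same way. A sketch under the same illustrative tool schema as above; how tool calls surface in stream chunks depends on the litellm and Vertex SDK versions, so only the raw deltas are printed:

# Hedged example: streaming a tools request over the Vertex AI Gemini route.
stream = litellm.completion(
    model="vertex_ai/gemini-pro-vision",
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=tools,
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta)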