fix(vertex_ai.py): support function calling for gemini

This commit is contained in:
Krrish Dholakia 2023-12-28 19:06:49 +05:30
parent a1484171b5
commit 86403cd14e
3 changed files with 167 additions and 95 deletions

View file

@ -5,7 +5,7 @@ import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
import litellm, uuid
import httpx
@ -264,14 +264,13 @@ def completion(
request_str = ""
response_obj = None
if model in litellm.vertex_language_models:
if (
model in litellm.vertex_language_models
or model in litellm.vertex_vision_models
):
llm_model = GenerativeModel(model)
mode = ""
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_vision_models:
llm_model = GenerativeModel(model)
request_str += f"llm_model = GenerativeModel({model})\n"
mode = "vision"
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_chat_models:
llm_model = ChatModel.from_pretrained(model)
mode = "chat"
@ -318,48 +317,10 @@ def completion(
**optional_params,
)
if mode == "":
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=stream,
)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
if "stream" in optional_params and optional_params["stream"] == True:
@ -379,6 +340,7 @@ def completion(
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=True,
tools=tools,
)
optional_params["stream"] = True
return model_response
@ -399,9 +361,35 @@ def completion(
contents=content,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
tools=tools,
)
completion_response = response.text
if tools is not None and hasattr(
response.candidates[0].content.parts[0], "function_call"
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v
args_str = json.dumps(args_dict)
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
completion_response = message
else:
completion_response = response.text
response_obj = response._raw_response
optional_params["tools"] = tools
elif mode == "chat":
chat = llm_model.start_chat()
request_str += f"chat = llm_model.start_chat()\n"
@ -479,7 +467,9 @@ def completion(
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
if isinstance(completion_response, litellm.Message):
model_response["choices"][0]["message"] = completion_response
elif len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
@ -533,26 +523,10 @@ async def async_completion(
try:
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params)
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
@ -570,10 +544,37 @@ async def async_completion(
## LLM Call
response = await llm_model._generate_content_async(
contents=content, generation_config=GenerationConfig(**optional_params)
contents=content,
generation_config=GenerationConfig(**optional_params),
tools=tools,
)
completion_response = response.text
if tools is not None and hasattr(
response.candidates[0].content.parts[0], "function_call"
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v
args_str = json.dumps(args_dict)
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
completion_response = message
else:
completion_response = response.text
response_obj = response._raw_response
optional_params["tools"] = tools
elif mode == "chat":
# chat-bison etc.
chat = llm_model.start_chat()
@ -609,7 +610,9 @@ async def async_completion(
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
if isinstance(completion_response, litellm.Message):
model_response["choices"][0]["message"] = completion_response
elif len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
@ -661,33 +664,14 @@ async def async_streaming(
"""
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
if mode == "vision":
stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params), stream=stream
)
optional_params["stream"] = True
elif mode == "vision":
stream = optional_params.pop("stream")
tools = optional_params.pop("tools", None)
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(
input=prompt,
@ -698,12 +682,13 @@ async def async_streaming(
},
)
response = llm_model._generate_content_streaming_async(
response = await llm_model._generate_content_streaming_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
stream=True,
tools=tools,
)
optional_params["stream"] = True
optional_params["tools"] = tools
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop(