fix(vertex_ai.py): fix vertex ai function calling

Krrish Dholakia 2024-02-22 21:28:12 -08:00
parent 21e1c2dc21
commit 9d17a0789f
3 changed files with 20 additions and 19 deletions


@@ -559,8 +559,7 @@ def completion(
                 f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
             )
             response = llm_model.predict(
-                endpoint=endpoint_path,
-                instances=instances
+                endpoint=endpoint_path, instances=instances
             ).predictions
 
             completion_response = response[0]
@@ -585,12 +584,8 @@ def completion(
                     "request_str": request_str,
                 },
             )
-            request_str += (
-                f"llm_model.predict(instances={instances})\n"
-            )
-            response = llm_model.predict(
-                instances=instances
-            ).predictions
+            request_str += f"llm_model.predict(instances={instances})\n"
+            response = llm_model.predict(instances=instances).predictions
 
             completion_response = response[0]
             if (
@@ -614,7 +609,6 @@ def completion(
-            model_response["choices"][0]["message"]["content"] = str(
-                completion_response
-            )
+            model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
             Vertex AI Model Garden
             """
+            from google.cloud import aiplatform
 
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -799,9 +794,7 @@ async def async_completion(
             completion_response = completion_response.split("\nOutput:\n", 1)[1]
         elif mode == "private":
-            request_str += (
-                f"llm_model.predict_async(instances={instances})\n"
-            )
+            request_str += f"llm_model.predict_async(instances={instances})\n"
             response_obj = await llm_model.predict_async(
                 instances=instances,
             )
@@ -826,7 +819,6 @@ async def async_completion(
-            model_response["choices"][0]["message"]["content"] = str(
-                completion_response
-            )
+            model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
+        from google.cloud import aiplatform
 
         stream = optional_params.pop("stream", None)
         ## LOGGING
@@ -972,7 +965,9 @@ async def async_streaming(
         endpoint_path = llm_model.endpoint_path(
             project=vertex_project, location=vertex_location, endpoint=model
         )
-        request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        request_str += (
+            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        )
         response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,

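Aside from black-style line wrapping, the substantive change in vertex_ai.py above is importing `google.cloud.aiplatform` inside `async_completion` and `async_streaming`, so the Model Garden (`mode == "custom"`) branches have the SDK in scope when they execute. Here is a minimal sketch of that deferred-import pattern, assuming google-cloud-aiplatform is installed; `predict_custom_model` is a hypothetical stand-in, not litellm's actual code:

```python
# Sketch only: `predict_custom_model` is a hypothetical stand-in for the
# "custom" mode branch in the diff above, not litellm's actual function.
def predict_custom_model(endpoint_path: str, instances: list):
    # Deferred import: the google-cloud-aiplatform dependency is only
    # pulled in when this code path actually runs.
    from google.cloud import aiplatform

    client = aiplatform.gapic.PredictionServiceClient()
    # PredictionServiceClient.predict accepts endpoint/instances kwargs;
    # instances must be protobuf-compatible values.
    return client.predict(endpoint=endpoint_path, instances=instances).predictions
```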

@@ -318,7 +318,7 @@ def test_gemini_pro_vision():
 # test_gemini_pro_vision()
 
 
-def gemini_pro_function_calling():
+def test_gemini_pro_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
@@ -345,12 +345,15 @@ def gemini_pro_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1
 
 
 # gemini_pro_function_calling()
 
 
-async def gemini_pro_async_function_calling():
+@pytest.mark.asyncio
+async def test_gemini_pro_async_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
@@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1
+    # raise Exception("it worked!")
 
 
 # asyncio.run(gemini_pro_async_function_calling())

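The test edits above do three things: rename the functions with a `test_` prefix so pytest's default collection picks them up, mark the async variant with `@pytest.mark.asyncio`, and assert the shape of a function-calling response. A sketch of what those assertions exercise, assuming valid Vertex AI credentials are configured; the `tools` payload here is an abbreviated weather-lookup schema, an assumption since the diff elides the test's real one:

```python
import litellm

# Abbreviated tool schema (assumed; the diff elides the test's real payload).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]

completion = litellm.completion(
    model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
# With function calling fixed, Gemini answers with a tool call rather than text.
assert completion.choices[0].message.content is None
assert len(completion.choices[0].message.tool_calls) == 1
```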

@@ -4274,8 +4274,8 @@ def get_optional_params(
         optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
-    elif custom_llm_provider == "vertex_ai" and model in (
-        litellm.vertex_chat_models
+    elif custom_llm_provider == "vertex_ai" and (
+        model in litellm.vertex_chat_models
         or model in litellm.vertex_code_chat_models
         or model in litellm.vertex_text_models
         or model in litellm.vertex_code_text_models
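This `get_optional_params` hunk is the actual function-calling fix. In the old condition, `or` evaluates inside the parentheses before `in` and short-circuits to its first truthy operand, so the branch reduced to `model in litellm.vertex_chat_models`; the code-chat, text, and code-text lists were never consulted. A toy reproduction of the precedence bug (the lists below are stand-ins, not litellm's real model lists):

```python
vertex_chat_models = ["chat-bison"]
vertex_text_models = ["text-bison"]
model = "text-bison"

# Old shape: `or` short-circuits to the non-empty first list, so this is
# equivalent to `model in vertex_chat_models` -- the second list is ignored.
buggy = model in (vertex_chat_models or model in vertex_text_models)

# Fixed shape: each membership test is spelled out explicitly.
fixed = model in vertex_chat_models or model in vertex_text_models

print(buggy)  # False
print(fixed)  # True
```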