fix(vertex_ai.py): fix vertex ai function calling

Krrish Dholakia 2024-02-22 21:28:12 -08:00
parent 21e1c2dc21
commit 9d17a0789f
3 changed files with 20 additions and 19 deletions


@@ -559,8 +559,7 @@ def completion(
     f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
 )
 response = llm_model.predict(
-    endpoint=endpoint_path,
-    instances=instances
+    endpoint=endpoint_path, instances=instances
 ).predictions
 completion_response = response[0]
@@ -585,12 +584,8 @@ def completion(
         "request_str": request_str,
     },
 )
-request_str += (
-    f"llm_model.predict(instances={instances})\n"
-)
-response = llm_model.predict(
-    instances=instances
-).predictions
+request_str += f"llm_model.predict(instances={instances})\n"
+response = llm_model.predict(instances=instances).predictions
 completion_response = response[0]
 if (
@@ -614,7 +609,6 @@ def completion(
 model_response["choices"][0]["message"]["content"] = str(
     completion_response
 )
-model_response["choices"][0]["message"]["content"] = str(completion_response)
 model_response["created"] = int(time.time())
 model_response["model"] = model
 ## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
 Vertex AI Model Garden
 """
 from google.cloud import aiplatform
+
 ## LOGGING
 logging_obj.pre_call(
     input=prompt,
@@ -797,11 +792,9 @@ async def async_completion(
     and "\nOutput:\n" in completion_response
 ):
     completion_response = completion_response.split("\nOutput:\n", 1)[1]
 elif mode == "private":
-    request_str += (
-        f"llm_model.predict_async(instances={instances})\n"
-    )
+    request_str += f"llm_model.predict_async(instances={instances})\n"
     response_obj = await llm_model.predict_async(
         instances=instances,
     )
@@ -826,7 +819,6 @@ async def async_completion(
 model_response["choices"][0]["message"]["content"] = str(
     completion_response
 )
-model_response["choices"][0]["message"]["content"] = str(completion_response)
 model_response["created"] = int(time.time())
 model_response["model"] = model
 ## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
 response = llm_model.predict_streaming_async(prompt, **optional_params)
 elif mode == "custom":
     from google.cloud import aiplatform
+
     stream = optional_params.pop("stream", None)
     ## LOGGING
@@ -972,7 +965,9 @@ async def async_streaming(
 endpoint_path = llm_model.endpoint_path(
     project=vertex_project, location=vertex_location, endpoint=model
 )
-request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+request_str += (
+    f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+)
 response_obj = await llm_model.predict(
     endpoint=endpoint_path,
     instances=instances,


@@ -318,7 +318,7 @@ def test_gemini_pro_vision():
 # test_gemini_pro_vision()


-def gemini_pro_function_calling():
+def test_gemini_pro_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
@@ -345,12 +345,15 @@ def gemini_pro_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1


 # gemini_pro_function_calling()


-async def gemini_pro_async_function_calling():
+@pytest.mark.asyncio
+async def test_gemini_pro_async_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
@@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1
+    # raise Exception("it worked!")


 # asyncio.run(gemini_pro_async_function_calling())
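
For context, the behaviour these new assertions pin down can be exercised directly through litellm. The following is a minimal sketch, not the test file's actual code: the tool schema is an abbreviated, hypothetical stand-in for the test's get_current_weather declaration, and working Vertex AI credentials are assumed (the tests set these up via load_vertex_ai_credentials()).

import litellm

# Hypothetical, abbreviated tool schema in the OpenAI tools format litellm accepts.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]

completion = litellm.completion(
    model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)

# With this fix, a Gemini function-call response surfaces as an OpenAI-style
# tool call: no text content, exactly one entry in tool_calls.
assert completion.choices[0].message.content is None
assert len(completion.choices[0].message.tool_calls) == 1
print(completion.choices[0].message.tool_calls[0].function.name)  # e.g. "get_current_weather"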


@@ -4274,8 +4274,8 @@ def get_optional_params(
     optional_params["stop_sequences"] = stop
 if max_tokens is not None:
     optional_params["max_output_tokens"] = max_tokens
-elif custom_llm_provider == "vertex_ai" and model in (
-    litellm.vertex_chat_models
+elif custom_llm_provider == "vertex_ai" and (
+    model in litellm.vertex_chat_models
     or model in litellm.vertex_code_chat_models
     or model in litellm.vertex_text_models
     or model in litellm.vertex_code_text_models
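
The old condition is an operator-precedence bug: in "model in (A or B or C)", Python evaluates the or-chain first, so membership is only ever tested against the first non-empty list. The rewrite distributes the "in" across each list. A toy illustration with hypothetical lists (not litellm's real model lists):

# Toy stand-ins for litellm.vertex_chat_models, litellm.vertex_code_chat_models, etc.
vertex_chat_models = ["chat-bison"]
vertex_code_chat_models = ["codechat-bison"]

model = "codechat-bison"

# Old form: the or-chain collapses to the first non-empty list,
# so membership is only ever tested against vertex_chat_models.
print(model in (vertex_chat_models or vertex_code_chat_models))  # False

# New form: membership is tested against each list separately.
print(model in vertex_chat_models or model in vertex_code_chat_models)  # True

Since the first list in the real chain is non-empty, the old check effectively ignored the remaining vertex model lists.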