fix(vertex_ai.py): fix vertex ai function calling

Krrish Dholakia 2024-02-22 21:28:12 -08:00
parent 21e1c2dc21
commit 9d17a0789f
3 changed files with 20 additions and 19 deletions

vertex_ai.py

@@ -559,8 +559,7 @@ def completion(
                 f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
             )
             response = llm_model.predict(
-                endpoint=endpoint_path,
-                instances=instances
+                endpoint=endpoint_path, instances=instances
             ).predictions

             completion_response = response[0]
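
The hunk above sits in the "custom" mode branch, where llm_model appears to be the low-level aiplatform.gapic.PredictionServiceClient and predictions are addressed by full endpoint path. A minimal sketch of that call pattern, assuming the google-cloud-aiplatform SDK; the project, location, and endpoint ID are placeholders, not values from this commit:

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# Placeholders - substitute a real project, region, and endpoint ID.
project, location, endpoint_id = "my-project", "us-central1", "123456789"

# The gapic client must target the deployment's regional API endpoint.
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
)
endpoint_path = client.endpoint_path(
    project=project, location=location, endpoint=endpoint_id
)
# Instances are JSON dicts wrapped as protobuf Values.
instances = [json_format.ParseDict({"prompt": "Hello"}, Value())]
response = client.predict(endpoint=endpoint_path, instances=instances)
print(response.predictions[0])
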
@@ -585,12 +584,8 @@ def completion(
                     "request_str": request_str,
                 },
             )
-            request_str += (
-                f"llm_model.predict(instances={instances})\n"
-            )
-            response = llm_model.predict(
-                instances=instances
-            ).predictions
+            request_str += f"llm_model.predict(instances={instances})\n"
+            response = llm_model.predict(instances=instances).predictions

             completion_response = response[0]
             if (
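
This hunk collapses a predict() call that passes only instances, with no endpoint path, i.e. the higher-level SDK endpoint object that litellm uses for Model Garden models served on private endpoints. A rough sketch of that surface, assuming aiplatform.PrivateEndpoint and a placeholder resource name:

from google.cloud import aiplatform

# Placeholders - a private endpoint is only reachable from inside its VPC.
aiplatform.init(project="my-project", location="us-central1")
endpoint = aiplatform.PrivateEndpoint(
    "projects/my-project/locations/us-central1/endpoints/123456789"
)
# The SDK-level predict() accepts plain dict instances; no protobuf wrapping.
response = endpoint.predict(instances=[{"prompt": "Hello"}])
print(response.predictions[0])
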
@@ -614,7 +609,6 @@ def completion(
-        model_response["choices"][0]["message"]["content"] = str(
-            completion_response
-        )
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
             Vertex AI Model Garden
             """
             from google.cloud import aiplatform
+
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -797,11 +792,9 @@ async def async_completion(
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
         elif mode == "private":
-            request_str += (
-                f"llm_model.predict_async(instances={instances})\n"
-            )
+            request_str += f"llm_model.predict_async(instances={instances})\n"
             response_obj = await llm_model.predict_async(
                 instances=instances,
             )
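
Here predict_async is the awaitable counterpart of predict() on the same endpoint object, which keeps the "private" branch from blocking the event loop inside async_completion. A usage sketch, again assuming aiplatform.PrivateEndpoint with a placeholder resource name:

import asyncio

from google.cloud import aiplatform

async def main():
    # Placeholder resource name, not a value from this commit.
    endpoint = aiplatform.PrivateEndpoint(
        "projects/my-project/locations/us-central1/endpoints/123456789"
    )
    # Same call shape as predict(), but awaitable.
    response = await endpoint.predict_async(instances=[{"prompt": "Hello"}])
    print(response.predictions[0])

asyncio.run(main())
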
@@ -826,7 +819,6 @@ async def async_completion(
-        model_response["choices"][0]["message"]["content"] = str(
-            completion_response
-        )
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
         from google.cloud import aiplatform
+
         stream = optional_params.pop("stream", None)
         ## LOGGING
@@ -972,7 +965,9 @@ async def async_streaming(
         endpoint_path = llm_model.endpoint_path(
             project=vertex_project, location=vertex_location, endpoint=model
         )
-        request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        request_str += (
+            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        )
         response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,
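
In this async_streaming "custom" branch, llm_model appears to be the async gapic client, and predict() returns the complete prediction in one response rather than a token stream, so any streaming has to be simulated downstream. A sketch of the underlying call; all identifiers are placeholders:

import asyncio

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

async def main():
    # Placeholders - substitute a real project, region, and endpoint ID.
    project, location, endpoint_id = "my-project", "us-central1", "123456789"
    client = aiplatform.gapic.PredictionServiceAsyncClient(
        client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
    )
    endpoint_path = client.endpoint_path(
        project=project, location=location, endpoint=endpoint_id
    )
    instances = [json_format.ParseDict({"prompt": "Hello"}, Value())]
    # The await yields the full prediction response at once.
    response = await client.predict(endpoint=endpoint_path, instances=instances)
    print(response.predictions[0])

asyncio.run(main())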