Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
fix(vertex_ai.py): fix vertex ai function calling

parent 21e1c2dc21
commit 9d17a0789f

3 changed files with 20 additions and 19 deletions
vertex_ai.py

@@ -559,8 +559,7 @@ def completion(
             f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
         )
         response = llm_model.predict(
-            endpoint=endpoint_path,
-            instances=instances
+            endpoint=endpoint_path, instances=instances
         ).predictions
 
         completion_response = response[0]
@@ -585,12 +584,8 @@ def completion(
                 "request_str": request_str,
             },
         )
-        request_str += (
-            f"llm_model.predict(instances={instances})\n"
-        )
-        response = llm_model.predict(
-            instances=instances
-        ).predictions
+        request_str += f"llm_model.predict(instances={instances})\n"
+        response = llm_model.predict(instances=instances).predictions
 
         completion_response = response[0]
         if (
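The two hunks above only join lines that fit within the formatter's line length; behavior is unchanged. For reference, a minimal standalone sketch of the underlying GAPIC predict call for a Model Garden endpoint, with placeholder project/location/endpoint values (the instance schema depends on the deployed model):

    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value

    # Placeholder values for illustration -- substitute your own deployment.
    project, location, endpoint_id = "my-project", "us-central1", "1234567890"

    client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
    )
    endpoint_path = client.endpoint_path(
        project=project, location=location, endpoint=endpoint_id
    )
    # The GAPIC client expects protobuf Values rather than plain dicts.
    instances = [json_format.ParseDict({"prompt": "Hello"}, Value())]
    response = client.predict(endpoint=endpoint_path, instances=instances)
    print(response.predictions[0])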
@@ -614,7 +609,6 @@ def completion(
-        model_response["choices"][0]["message"]["content"] = str(
-            completion_response
-        )
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
         Vertex AI Model Garden
         """
+        from google.cloud import aiplatform
 
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
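The hunk above adds the aiplatform import inside the function body. A short sketch of the general pattern, assuming the goal is to keep google-cloud-aiplatform an optional dependency (a common reason for function-local imports; the commit itself does not state one):

    def predict_with_optional_dep(prompt: str):
        # Importing here instead of at module top level means
        # google-cloud-aiplatform only has to be installed when this
        # code path actually runs.
        from google.cloud import aiplatform

        aiplatform.init(project="my-project", location="us-central1")  # placeholders
        ...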
@@ -797,11 +792,9 @@ async def async_completion(
             and "\nOutput:\n" in completion_response
         ):
             completion_response = completion_response.split("\nOutput:\n", 1)[1]
 
     elif mode == "private":
-        request_str += (
-            f"llm_model.predict_async(instances={instances})\n"
-        )
+        request_str += f"llm_model.predict_async(instances={instances})\n"
         response_obj = await llm_model.predict_async(
             instances=instances,
         )
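The `predict_async` call in the context lines above is the aiplatform SDK's async endpoint API. A minimal sketch, assuming a recent google-cloud-aiplatform release that ships `Endpoint.predict_async` and a placeholder endpoint resource name:

    import asyncio

    from google.cloud import aiplatform

    async def main():
        # Placeholder resource name -- substitute your own endpoint.
        endpoint = aiplatform.Endpoint(
            "projects/my-project/locations/us-central1/endpoints/1234567890"
        )
        response = await endpoint.predict_async(instances=[{"prompt": "Hello"}])
        print(response.predictions[0])

    asyncio.run(main())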
@@ -826,7 +819,6 @@ async def async_completion(
-        model_response["choices"][0]["message"]["content"] = str(
-            completion_response
-        )
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
+        from google.cloud import aiplatform
 
         stream = optional_params.pop("stream", None)
 
         ## LOGGING
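The `predict_streaming_async` call in the context line above is the vertexai SDK's streaming text-model API. A sketch, assuming the vertexai SDK is installed and using the placeholder `text-bison` foundation model:

    import asyncio

    import vertexai
    from vertexai.language_models import TextGenerationModel

    async def main():
        vertexai.init(project="my-project", location="us-central1")  # placeholders
        model = TextGenerationModel.from_pretrained("text-bison")
        # Chunks arrive as TextGenerationResponse objects as they stream in.
        async for chunk in model.predict_streaming_async("Tell me a story"):
            print(chunk.text, end="")

    asyncio.run(main())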
@@ -972,7 +965,9 @@ async def async_streaming(
         endpoint_path = llm_model.endpoint_path(
             project=vertex_project, location=vertex_location, endpoint=model
         )
-        request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        request_str += (
+            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        )
         response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,
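Taken together, these code paths back litellm's Vertex AI Model Garden support. A usage sketch, assuming a deployed endpoint ID (placeholder below) and the module-level project/location settings litellm documents for Vertex AI:

    import litellm

    litellm.vertex_project = "my-project"    # placeholder GCP project
    litellm.vertex_location = "us-central1"  # placeholder region

    response = litellm.completion(
        model="vertex_ai/1234567890",  # placeholder Model Garden endpoint ID
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response["choices"][0]["message"]["content"])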