Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26)
fix(vertex_ai.py): fix vertex ai function calling
This commit is contained in:
parent 21e1c2dc21
commit 9d17a0789f

3 changed files with 20 additions and 19 deletions
@@ -559,8 +559,7 @@ def completion(
                 f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
             )
             response = llm_model.predict(
-                endpoint=endpoint_path,
-                instances=instances
+                endpoint=endpoint_path, instances=instances
             ).predictions

             completion_response = response[0]
@@ -585,12 +584,8 @@ def completion(
                     "request_str": request_str,
                 },
             )
-            request_str += (
-                f"llm_model.predict(instances={instances})\n"
-            )
-            response = llm_model.predict(
-                instances=instances
-            ).predictions
+            request_str += f"llm_model.predict(instances={instances})\n"
+            response = llm_model.predict(instances=instances).predictions

             completion_response = response[0]
             if (
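Note: the two hunks above only reflow multi-line llm_model.predict(...) calls onto single lines; behavior is unchanged. For orientation, a minimal sketch of the underlying google-cloud-aiplatform call pattern, with placeholder project and endpoint values that are not part of this commit:

from google.cloud import aiplatform

# Placeholder project/location/endpoint ID (hypothetical, for illustration).
aiplatform.init(project="my-project", location="us-central1")
llm_model = aiplatform.Endpoint("1234567890")

# Endpoint.predict returns a Prediction object whose .predictions is a list.
instances = [{"prompt": "Say hello."}]
response = llm_model.predict(instances=instances).predictions
completion_response = response[0]
print(completion_response)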
@@ -614,7 +609,6 @@ def completion(
-        model_response["choices"][0]["message"]["content"] = str(
-            completion_response
-        )
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
         Vertex AI Model Garden
         """
         from google.cloud import aiplatform
+
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
@@ -797,11 +792,9 @@ async def async_completion(
             and "\nOutput:\n" in completion_response
         ):
             completion_response = completion_response.split("\nOutput:\n", 1)[1]

     elif mode == "private":
-        request_str += (
-            f"llm_model.predict_async(instances={instances})\n"
-        )
+        request_str += f"llm_model.predict_async(instances={instances})\n"
         response_obj = await llm_model.predict_async(
             instances=instances,
         )
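Note: the "private" branch awaits predict_async on the endpoint object, exactly as the hunk shows. A hedged, self-contained sketch of that call, assuming a PrivateEndpoint deployment (the resource name is a placeholder, not from the commit):

import asyncio
from google.cloud import aiplatform

async def main():
    # Placeholder private-endpoint resource name (hypothetical).
    llm_model = aiplatform.PrivateEndpoint(
        "projects/my-project/locations/us-central1/endpoints/123"
    )
    instances = [{"prompt": "Say hello."}]
    # Asynchronous prediction, as used by async_completion above.
    response_obj = await llm_model.predict_async(instances=instances)
    print(response_obj.predictions[0])

asyncio.run(main())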
@@ -826,7 +819,6 @@ async def async_completion(
-        model_response["choices"][0]["message"]["content"] = str(
-            completion_response
-        )
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
         from google.cloud import aiplatform

+        stream = optional_params.pop("stream", None)

         ## LOGGING
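Note: the added line pops "stream" out of optional_params before the remaining parameters are forwarded, so the raw predict call never receives a keyword it does not accept. The pattern in isolation:

# dict.pop with a default removes the key if present and returns None
# otherwise, leaving the remaining kwargs safe to forward.
optional_params = {"stream": True, "temperature": 0.2}
stream = optional_params.pop("stream", None)  # stream is now True
assert "stream" not in optional_params        # only temperature remains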
@@ -972,7 +965,9 @@ async def async_streaming(
         endpoint_path = llm_model.endpoint_path(
             project=vertex_project, location=vertex_location, endpoint=model
         )
-        request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        request_str += (
+            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        )
         response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,
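Note: in this branch llm_model is a GAPIC prediction client, so the request needs a fully qualified endpoint path and protobuf-encoded instances. A hedged sketch of that pattern (project, location, and endpoint ID are placeholders, not from the commit):

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

async def predict_custom(prompt: str):
    client = aiplatform.gapic.PredictionServiceAsyncClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    )
    # Build "projects/.../locations/.../endpoints/..." from its parts.
    endpoint_path = client.endpoint_path(
        project="my-project", location="us-central1", endpoint="123"
    )
    # GAPIC predict expects instances as protobuf Values.
    instance = json_format.ParseDict({"prompt": prompt}, Value())
    response_obj = await client.predict(endpoint=endpoint_path, instances=[instance])
    return response_obj.predictions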
@@ -318,7 +318,7 @@ def test_gemini_pro_vision():
 # test_gemini_pro_vision()


-def gemini_pro_function_calling():
+def test_gemini_pro_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
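Note: renaming gemini_pro_function_calling to test_gemini_pro_function_calling puts the function under pytest's default test_* collection rule, so it now runs automatically instead of needing the manual call commented out below. The tools list is truncated at "{" in this hunk; a hypothetical OpenAI-style tool spec of the kind litellm accepts (illustrative, not the commit's exact payload):

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical example tool
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"],
            },
        },
    }
]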
@@ -345,12 +345,15 @@ def gemini_pro_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1


 # gemini_pro_function_calling()


-async def gemini_pro_async_function_calling():
+@pytest.mark.asyncio
+async def test_gemini_pro_async_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
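Note: the new assertions pin down the expected function-calling behavior: when the model decides to call a tool, the message carries tool_calls and no plain text content. A self-contained usage sketch, assuming valid Vertex AI credentials and the hypothetical tools list shown above:

import litellm

messages = [{"role": "user", "content": "What's the weather in Boston?"}]
completion = litellm.completion(
    model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
# A tool-calling turn returns exactly one tool call and no text content.
assert completion.choices[0].message.content is None
assert len(completion.choices[0].message.tool_calls) == 1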
@@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1
+    # raise Exception("it worked!")


 # asyncio.run(gemini_pro_async_function_calling())
@@ -4274,8 +4274,8 @@ def get_optional_params(
             optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
-    elif custom_llm_provider == "vertex_ai" and model in (
-        litellm.vertex_chat_models
+    elif custom_llm_provider == "vertex_ai" and (
+        model in litellm.vertex_chat_models
+        or model in litellm.vertex_code_chat_models
+        or model in litellm.vertex_text_models
+        or model in litellm.vertex_code_text_models
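Note: the rewritten condition groups four membership tests under one parenthesized or-chain, so this elif now matches a model from any of the four Vertex model families instead of only what the old model in (...) expression covered. The shape of the new test in isolation (toy lists, not litellm's real model registries):

vertex_chat_models = ["chat-bison"]
vertex_code_chat_models = ["codechat-bison"]
vertex_text_models = ["text-bison"]
vertex_code_text_models = ["code-bison"]

def is_vertex_text_family(model: str) -> bool:
    # One membership check per family, joined with `or`.
    return (
        model in vertex_chat_models
        or model in vertex_code_chat_models
        or model in vertex_text_models
        or model in vertex_code_text_models
    )

assert is_vertex_text_family("code-bison")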
|
Loading…
Add table
Add a link
Reference in a new issue