Mirror of https://github.com/BerriAI/litellm.git
fix(vertex_ai.py): handle stream=false
Also adds unit tests for Vertex AI calls made via LangChain.
Parent: cff83c720d
Commit: 5f8d88d363
2 changed files with 45 additions and 16 deletions
vertex_ai.py (first changed file):

@@ -143,7 +143,9 @@ class VertexAIConfig:
             optional_params["temperature"] = value
         if param == "top_p":
             optional_params["top_p"] = value
-        if param == "stream":
+        if (
+            param == "stream" and value == True
+        ):  # sending stream = False, can cause it to get passed unchecked and raise issues
             optional_params["stream"] = value
         if param == "n":
             optional_params["candidate_count"] = value
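For context, a standalone sketch of what the new guard changes. The map_params helper below is hypothetical (it only mirrors the mapping loop in the hunk above, it is not litellm's real function): with the old check, stream=False was copied into optional_params and forwarded downstream; with the guard, only stream=True survives the mapping.

# Hypothetical re-creation of the mapping guard above, not litellm's actual helper.
def map_params(non_default_params: dict) -> dict:
    optional_params = {}
    for param, value in non_default_params.items():
        if param == "temperature":
            optional_params["temperature"] = value
        if param == "top_p":
            optional_params["top_p"] = value
        if (
            param == "stream" and value == True
        ):  # stream=False is dropped instead of being forwarded unchecked
            optional_params["stream"] = value
        if param == "n":
            optional_params["candidate_count"] = value
    return optional_params

print(map_params({"temperature": 0.1, "stream": False}))  # {'temperature': 0.1} -- no stream key
print(map_params({"temperature": 0.1, "stream": True}))   # {'temperature': 0.1, 'stream': True}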
@@ -541,8 +543,9 @@ def completion(
             tools = optional_params.pop("tools", None)
             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
-            if "stream" in optional_params and optional_params["stream"] == True:
-                stream = optional_params.pop("stream")
+            stream = optional_params.pop("stream", False)
+
+            if stream == True:
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 logging_obj.pre_call(
                     input=prompt,
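The unconditional pop matters because the remaining optional_params are splatted into GenerationConfig, which has no stream field; a leftover "stream" key (for example an explicit stream=False from the caller) would be rejected. A minimal sketch of the call pattern used in this hunk, assuming the vertexai SDK is installed and vertexai.init() has already been called (model name and prompt are placeholders):

from vertexai.preview.generative_models import GenerationConfig, GenerativeModel

optional_params = {"temperature": 0.1, "stream": False}  # as received from the caller

# Remove "stream" before building the config; GenerationConfig(**optional_params)
# would raise a TypeError if the key were still present.
stream = optional_params.pop("stream", False)

llm_model = GenerativeModel("gemini-pro")
response = llm_model.generate_content(
    ["What's the weather like in Boston?"],
    generation_config=GenerationConfig(**optional_params),
    stream=stream,
)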
@@ -820,6 +823,7 @@ async def async_completion(
             print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
             tools = optional_params.pop("tools", None)
+            stream = optional_params.pop("stream", False)

             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
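Taken together, the vertex_ai.py hunks mean an explicit stream=False now behaves like a normal non-streaming call instead of raising. A minimal usage sketch mirroring the tests below, assuming Vertex AI credentials are configured for litellm:

import litellm

messages = [
    {"role": "user", "content": "What's the weather like in Boston today in fahrenheit?"}
]

# stream=False is handled instead of being passed through unchecked.
response = litellm.completion(model="gemini-pro", messages=messages, stream=False)
assert isinstance(response, litellm.ModelResponse)

# stream=True still yields chunks.
for chunk in litellm.completion(model="gemini-pro", messages=messages, stream=True):
    print(chunk)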
Vertex AI tests (second changed file):

@@ -636,7 +636,10 @@ def test_gemini_pro_function_calling():
 # gemini_pro_function_calling()


-def test_gemini_pro_function_calling_streaming():
+@pytest.mark.parametrize("stream", [False, True])
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_gemini_pro_function_calling_streaming(stream, sync_mode):
     load_vertex_ai_credentials()
     litellm.set_verbose = True
     tools = [
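The stacked parametrize marks above expand into a cross product, so the one test body is collected four times (sync/async x streaming on/off), with pytest-asyncio driving the coroutine. A self-contained illustration of that expansion:

import pytest

@pytest.mark.parametrize("stream", [False, True])
@pytest.mark.parametrize("sync_mode", [False, True])
def test_matrix(stream, sync_mode):
    # Stacked parametrize marks produce the cross product, so pytest collects four
    # items: (stream, sync_mode) over {False, True} x {False, True}.
    assert isinstance(stream, bool) and isinstance(sync_mode, bool)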
@@ -665,19 +668,41 @@ def test_gemini_pro_function_calling_streaming():
             "content": "What's the weather like in Boston today in fahrenheit?",
         }
     ]
+    optional_params = {
+        "tools": tools,
+        "tool_choice": "auto",
+        "n": 1,
+        "stream": stream,
+        "temperature": 0.1,
+    }
     try:
-        completion = litellm.completion(
-            model="gemini-pro",
-            messages=messages,
-            tools=tools,
-            tool_choice="auto",
-            stream=True,
-        )
-        print(f"completion: {completion}")
-        # assert completion.choices[0].message.content is None
-        # assert len(completion.choices[0].message.tool_calls) == 1
-        for chunk in completion:
-            print(f"chunk: {chunk}")
+        if sync_mode == True:
+            response = litellm.completion(
+                model="gemini-pro", messages=messages, **optional_params
+            )
+            print(f"completion: {response}")
+
+            if stream == True:
+                # assert completion.choices[0].message.content is None
+                # assert len(completion.choices[0].message.tool_calls) == 1
+                for chunk in response:
+                    assert isinstance(chunk, litellm.ModelResponse)
+            else:
+                assert isinstance(response, litellm.ModelResponse)
+        else:
+            response = await litellm.acompletion(
+                model="gemini-pro", messages=messages, **optional_params
+            )
+            print(f"completion: {response}")
+
+            if stream == True:
+                # assert completion.choices[0].message.content is None
+                # assert len(completion.choices[0].message.tool_calls) == 1
+                async for chunk in response:
+                    print(f"chunk: {chunk}")
+                    assert isinstance(chunk, litellm.ModelResponse)
+            else:
+                assert isinstance(response, litellm.ModelResponse)
     except litellm.APIError as e:
         pass
     except litellm.RateLimitError as e:
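The commit description also mentions exercising Vertex AI calls through LangChain; per that description, requests made via LangChain hit the non-streaming path handled above. A hedged sketch using the ChatLiteLLM wrapper from langchain_community (the wrapper and its configuration are assumptions about the caller's setup, not part of this diff):

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import HumanMessage

# A non-streaming invoke() ultimately calls litellm.completion for the Gemini model,
# exercising the stream=False / non-streaming path this commit fixes.
chat = ChatLiteLLM(model="gemini-pro")
result = chat.invoke([HumanMessage(content="What's the weather like in Boston today in fahrenheit?")])
print(result.content)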