fix(vertex_ai.py): handle stream=false

also adds unit testing for vertex ai calls with langchain
Krrish Dholakia 2024-04-25 13:59:15 -07:00
parent cff83c720d
commit 5f8d88d363
2 changed files with 45 additions and 16 deletions
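For context, a minimal usage sketch (not part of this diff; the model name and message are illustrative) of the call pattern the fix targets: passing stream=False explicitly should behave like omitting the flag, rather than being forwarded unchecked to the Vertex AI SDK.

    import litellm

    # Hypothetical reproduction of the scenario being fixed.
    response = litellm.completion(
        model="gemini-pro",
        messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
        stream=False,  # before this fix, stream=False could be passed through unchecked and raise issues
    )
    print(response)  # a regular (non-streaming) litellm.ModelResponse is expected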


@@ -143,7 +143,9 @@ class VertexAIConfig:
                 optional_params["temperature"] = value
             if param == "top_p":
                 optional_params["top_p"] = value
-            if param == "stream":
+            if (
+                param == "stream" and value == True
+            ):  # sending stream = False, can cause it to get passed unchecked and raise issues
                 optional_params["stream"] = value
             if param == "n":
                 optional_params["candidate_count"] = value
@@ -541,8 +543,9 @@ def completion(
             tools = optional_params.pop("tools", None)
             prompt, images = _gemini_vision_convert_messages(messages=messages)
             content = [prompt] + images
-            if "stream" in optional_params and optional_params["stream"] == True:
-                stream = optional_params.pop("stream")
+            stream = optional_params.pop("stream", False)
+
+            if stream == True:
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 logging_obj.pre_call(
                     input=prompt,
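The same fix also replaces the membership check with a pop that takes a default, so the "stream" key is removed from optional_params whether it was True, False, or absent, and only the remaining keys reach GenerationConfig(**optional_params). A minimal sketch of that pattern (values assumed for illustration):

    # Hypothetical illustration of pop-with-default as used above.
    optional_params = {"temperature": 0.1, "stream": False}
    stream = optional_params.pop("stream", False)  # returns False and removes the key
    assert stream == False
    assert "stream" not in optional_params  # GenerationConfig(**optional_params) never sees it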
@@ -820,6 +823,7 @@ async def async_completion(
         print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
         tools = optional_params.pop("tools", None)
+        stream = optional_params.pop("stream", False)
         prompt, images = _gemini_vision_convert_messages(messages=messages)
         content = [prompt] + images


@@ -636,7 +636,10 @@ def test_gemini_pro_function_calling():
 # gemini_pro_function_calling()


-def test_gemini_pro_function_calling_streaming():
+@pytest.mark.parametrize("stream", [False, True])
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_gemini_pro_function_calling_streaming(stream, sync_mode):
     load_vertex_ai_credentials()
     litellm.set_verbose = True
     tools = [
@@ -665,19 +668,41 @@ def test_gemini_pro_function_calling_streaming():
             "content": "What's the weather like in Boston today in fahrenheit?",
         }
     ]
+    optional_params = {
+        "tools": tools,
+        "tool_choice": "auto",
+        "n": 1,
+        "stream": stream,
+        "temperature": 0.1,
+    }
     try:
-        completion = litellm.completion(
-            model="gemini-pro",
-            messages=messages,
-            tools=tools,
-            tool_choice="auto",
-            stream=True,
-        )
-        print(f"completion: {completion}")
-        # assert completion.choices[0].message.content is None
-        # assert len(completion.choices[0].message.tool_calls) == 1
-        for chunk in completion:
-            print(f"chunk: {chunk}")
+        if sync_mode == True:
+            response = litellm.completion(
+                model="gemini-pro", messages=messages, **optional_params
+            )
+            print(f"completion: {response}")
+
+            if stream == True:
+                # assert completion.choices[0].message.content is None
+                # assert len(completion.choices[0].message.tool_calls) == 1
+                for chunk in response:
+                    assert isinstance(chunk, litellm.ModelResponse)
+            else:
+                assert isinstance(response, litellm.ModelResponse)
+        else:
+            response = await litellm.acompletion(
+                model="gemini-pro", messages=messages, **optional_params
+            )
+            print(f"completion: {response}")
+            if stream == True:
+                # assert completion.choices[0].message.content is None
+                # assert len(completion.choices[0].message.tool_calls) == 1
+                async for chunk in response:
+                    print(f"chunk: {chunk}")
+                    assert isinstance(chunk, litellm.ModelResponse)
+            else:
+                assert isinstance(response, litellm.ModelResponse)
     except litellm.APIError as e:
         pass
     except litellm.RateLimitError as e: