Fix fireworks and update the test

Don't look for eom_id / eot_id sadly since providers don't return the
last token
This commit is contained in:
Ashwin Bharambe 2024-10-07 17:43:47 -07:00 committed by Ashwin Bharambe
parent bbd3a02615
commit dba7caf1d0
4 changed files with 37 additions and 37 deletions

View file

@ -222,8 +222,9 @@ async def test_chat_completion_with_tool_calling(
message = response[0].completion_message
stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
assert message.stop_reason == stop_reason
# This is not supported in most providers :/ they don't return eom_id / eot_id
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
# assert message.stop_reason == stop_reason
assert message.tool_calls is not None
assert len(message.tool_calls) > 0
@ -266,10 +267,12 @@ async def test_chat_completion_with_tool_calling_streaming(
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
end = grouped[ChatCompletionResponseEventType.complete][0]
expected_stop_reason = get_expected_stop_reason(
inference_settings["common_params"]["model"]
)
assert end.event.stop_reason == expected_stop_reason
# This is not supported in most providers :/ they don't return eom_id / eot_id
# expected_stop_reason = get_expected_stop_reason(
# inference_settings["common_params"]["model"]
# )
# assert end.event.stop_reason == expected_stop_reason
model = inference_settings["common_params"]["model"]
if "Llama3.1" in model:
@ -281,7 +284,7 @@ async def test_chat_completion_with_tool_calling_streaming(
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1]
assert last.event.stop_reason == expected_stop_reason
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success
assert isinstance(last.event.delta.content, ToolCall)