Update condition in tests to handle llama-3.1 vs llama3.1 (HF names)

This commit is contained in:
Ashwin Bharambe 2024-11-19 13:25:36 -08:00
parent 394519d68a
commit 05d1ead02f

View file

@ -25,7 +25,11 @@ from .utils import group_chunks
def get_expected_stop_reason(model: str):
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
return (
StopReason.end_of_message
if ("Llama3.1" in model or "Llama-3.1" in model)
else StopReason.end_of_turn
)
@pytest.fixture
@ -34,7 +38,7 @@ def common_params(inference_model):
"tool_choice": ToolChoice.auto,
"tool_prompt_format": (
ToolPromptFormat.json
if "Llama3.1" in inference_model
if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
else ToolPromptFormat.python_list
),
}