diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 7b7aca5bd..6e263432a 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -25,7 +25,11 @@ from .utils import group_chunks def get_expected_stop_reason(model: str): - return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn + return ( + StopReason.end_of_message + if ("Llama3.1" in model or "Llama-3.1" in model) + else StopReason.end_of_turn + ) @pytest.fixture @@ -34,7 +38,7 @@ def common_params(inference_model): "tool_choice": ToolChoice.auto, "tool_prompt_format": ( ToolPromptFormat.json - if "Llama3.1" in inference_model + if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model) else ToolPromptFormat.python_list ), }