From 54abeeebce14ddb4d03198f659753540463b3259 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Mar 2025 16:24:43 -0800 Subject: [PATCH] default text model --- tests/integration/conftest.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index dada5449f..ade1893f7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -86,7 +86,7 @@ def pytest_addoption(parser): ) parser.addoption( "--judge-model", - default=None, + default=TEXT_MODEL, help="Specify the judge model to use for testing", ) parser.addoption( @@ -230,10 +230,16 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request): # Replace the methods with recordable mocks inference_mock.chat_completion = RecordableMock( - original_inference.chat_completion, cache_dir, "chat_completion", record=record_responses + original_inference.chat_completion, + cache_dir, + "chat_completion", + record=record_responses, ) inference_mock.completion = RecordableMock( - original_inference.completion, cache_dir, "text_completion", record=record_responses + original_inference.completion, + cache_dir, + "text_completion", + record=record_responses, ) inference_mock.embeddings = RecordableMock( original_inference.embeddings, cache_dir, "embeddings", record=record_responses @@ -247,7 +253,10 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request): # Replace the methods with recordable mocks tool_runtime_mock.invoke_tool = RecordableMock( - original_tool_runtime_api.invoke_tool, cache_dir, "invoke_tool", record=record_responses + original_tool_runtime_api.invoke_tool, + cache_dir, + "invoke_tool", + record=record_responses, ) agents_impl.tool_runtime_api = tool_runtime_mock @@ -267,7 +276,12 @@ def inference_provider_type(llama_stack_client): @pytest.fixture(scope="session") def client_with_models( - llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension, judge_model_id + llama_stack_client, + text_model_id, + vision_model_id, + embedding_model_id, + embedding_dimension, + judge_model_id, ): client = llama_stack_client