default text model

This commit is contained in:
Xi Yan 2025-03-05 16:24:43 -08:00
parent 5d43b9157e
commit 54abeeebce

View file

@ -86,7 +86,7 @@ def pytest_addoption(parser):
) )
parser.addoption( parser.addoption(
"--judge-model", "--judge-model",
default=None, default=TEXT_MODEL,
help="Specify the judge model to use for testing", help="Specify the judge model to use for testing",
) )
parser.addoption( parser.addoption(
@ -230,10 +230,16 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):
# Replace the methods with recordable mocks # Replace the methods with recordable mocks
inference_mock.chat_completion = RecordableMock( inference_mock.chat_completion = RecordableMock(
original_inference.chat_completion, cache_dir, "chat_completion", record=record_responses original_inference.chat_completion,
cache_dir,
"chat_completion",
record=record_responses,
) )
inference_mock.completion = RecordableMock( inference_mock.completion = RecordableMock(
original_inference.completion, cache_dir, "text_completion", record=record_responses original_inference.completion,
cache_dir,
"text_completion",
record=record_responses,
) )
inference_mock.embeddings = RecordableMock( inference_mock.embeddings = RecordableMock(
original_inference.embeddings, cache_dir, "embeddings", record=record_responses original_inference.embeddings, cache_dir, "embeddings", record=record_responses
@ -247,7 +253,10 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):
# Replace the methods with recordable mocks # Replace the methods with recordable mocks
tool_runtime_mock.invoke_tool = RecordableMock( tool_runtime_mock.invoke_tool = RecordableMock(
original_tool_runtime_api.invoke_tool, cache_dir, "invoke_tool", record=record_responses original_tool_runtime_api.invoke_tool,
cache_dir,
"invoke_tool",
record=record_responses,
) )
agents_impl.tool_runtime_api = tool_runtime_mock agents_impl.tool_runtime_api = tool_runtime_mock
@ -267,7 +276,12 @@ def inference_provider_type(llama_stack_client):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def client_with_models( def client_with_models(
llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension, judge_model_id llama_stack_client,
text_model_id,
vision_model_id,
embedding_model_id,
embedding_dimension,
judge_model_id,
): ):
client = llama_stack_client client = llama_stack_client