From 332629438ada1f1844f11d04b198a9efe47f71db Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Sat, 1 Mar 2025 17:38:41 -0500
Subject: [PATCH] Address feedback

Signed-off-by: Yuan Tang
---
 tests/client-sdk/conftest.py                  | 41 ++++++----
 .../inference/test_text_inference.py          | 76 ++++++++++---------
 .../inference/test_vision_inference.py        | 14 ++--
 3 files changed, 74 insertions(+), 57 deletions(-)

diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py
index c26f9e157..962373a91 100644
--- a/tests/client-sdk/conftest.py
+++ b/tests/client-sdk/conftest.py
@@ -6,7 +6,7 @@
 import os
 
 import pytest
-from llama_stack_client import BadRequestError, LlamaStackClient
+from llama_stack_client import LlamaStackClient
 from report import Report
 
 from llama_stack import LlamaStackAsLibraryClient
@@ -109,28 +109,39 @@ def inference_provider_type(llama_stack_client):
     return inference_providers[0].provider_type
 
 
-@pytest.fixture(scope="session")
-def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
-    client = llama_stack_client
-
+def get_model_ids_and_providers(client):
     providers = [p for p in client.providers.list() if p.api == "inference"]
     assert len(providers) > 0, "No inference providers found"
     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
 
     model_ids = {m.identifier for m in client.models.list()}
     model_ids.update(m.provider_resource_id for m in client.models.list())
+    return model_ids, inference_providers, providers
 
-    try:
-        if text_model_id and text_model_id not in model_ids:
-            client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
-    except BadRequestError:
-        pass
 
-    try:
-        if vision_model_id and vision_model_id not in model_ids:
-            client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
-    except BadRequestError:
-        pass
+@pytest.fixture(scope="session")
+def client_with_text_models(llama_stack_client, text_model_id):
+    client = llama_stack_client
+    model_ids, inference_providers, providers = get_model_ids_and_providers(llama_stack_client)
+
+    if text_model_id and text_model_id not in model_ids:
+        client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
+    return client
+
+
+@pytest.fixture(scope="session")
+def client_with_vision_models(llama_stack_client, vision_model_id):
+    client = llama_stack_client
+    model_ids, inference_providers, providers = get_model_ids_and_providers(llama_stack_client)
+    if vision_model_id and vision_model_id not in model_ids:
+        client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
+    return client
+
+
+@pytest.fixture(scope="session")
+def client_with_embedding_models(llama_stack_client, embedding_model_id, embedding_dimension):
+    client = llama_stack_client
+    model_ids, inference_providers, providers = get_model_ids_and_providers(llama_stack_client)
 
     if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
         # try to find a provider that supports embeddings, if sentence-transformers is not available
         selected_provider = None
diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py
index 63813a1cc..18c796e2e 100644
--- a/tests/client-sdk/inference/test_text_inference.py
+++ b/tests/client-sdk/inference/test_text_inference.py
@@ -14,18 +14,18 @@ from llama_stack.providers.tests.test_cases.test_case import TestCase
 PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
{"remote::together", "remote::fireworks", "remote::vllm"} -def skip_if_model_doesnt_support_completion(client_with_models, model_id): - models = {m.identifier: m for m in client_with_models.models.list()} +def skip_if_model_doesnt_support_completion(client_with_text_models, model_id): + models = {m.identifier: m for m in client_with_text_models.models.list()} provider_id = models[model_id].provider_id - providers = {p.provider_id: p for p in client_with_models.providers.list()} + providers = {p.provider_id: p for p in client_with_text_models.providers.list()} provider = providers[provider_id] if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion") -def get_llama_model(client_with_models, model_id): +def get_llama_model(client_with_text_models, model_id): models = {} - for m in client_with_models.models.list(): + for m in client_with_text_models.models.list(): models[m.identifier] = m models[m.provider_resource_id] = m @@ -46,11 +46,11 @@ def get_llama_model(client_with_models, model_id): "inference:completion:sanity", ], ) -def test_text_completion_non_streaming(client_with_models, text_model_id, test_case): - skip_if_model_doesnt_support_completion(client_with_models, text_model_id) +def test_text_completion_non_streaming(client_with_text_models, text_model_id, test_case): + skip_if_model_doesnt_support_completion(client_with_text_models, text_model_id) tc = TestCase(test_case) - response = client_with_models.inference.completion( + response = client_with_text_models.inference.completion( content=tc["content"], stream=False, model_id=text_model_id, @@ -68,11 +68,11 @@ def test_text_completion_non_streaming(client_with_models, text_model_id, test_c "inference:completion:sanity", ], ) -def test_text_completion_streaming(client_with_models, text_model_id, test_case): - skip_if_model_doesnt_support_completion(client_with_models, text_model_id) +def test_text_completion_streaming(client_with_text_models, text_model_id, test_case): + skip_if_model_doesnt_support_completion(client_with_text_models, text_model_id) tc = TestCase(test_case) - response = client_with_models.inference.completion( + response = client_with_text_models.inference.completion( content=tc["content"], stream=True, model_id=text_model_id, @@ -92,14 +92,16 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case) "inference:completion:log_probs", ], ) -def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case): - skip_if_model_doesnt_support_completion(client_with_models, text_model_id) +def test_text_completion_log_probs_non_streaming( + client_with_text_models, text_model_id, inference_provider_type, test_case +): + skip_if_model_doesnt_support_completion(client_with_text_models, text_model_id) if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") tc = TestCase(test_case) - response = client_with_models.inference.completion( + response = client_with_text_models.inference.completion( content=tc["content"], stream=False, model_id=text_model_id, @@ -121,14 +123,16 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_ "inference:completion:log_probs", ], ) -def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case): - 
+def test_text_completion_log_probs_streaming(
+    client_with_text_models, text_model_id, inference_provider_type, test_case
+):
+    skip_if_model_doesnt_support_completion(client_with_text_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
 
     tc = TestCase(test_case)
 
-    response = client_with_models.inference.completion(
+    response = client_with_text_models.inference.completion(
         content=tc["content"],
         stream=True,
         model_id=text_model_id,
@@ -154,8 +158,8 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
         "inference:completion:structured_output",
     ],
 )
-def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
-    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+def test_text_completion_structured_output(client_with_text_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_completion(client_with_text_models, text_model_id)
 
     class AnswerFormat(BaseModel):
         name: str
@@ -165,7 +169,7 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te
     tc = TestCase(test_case)
 
     user_input = tc["user_input"]
-    response = client_with_models.inference.completion(
+    response = client_with_text_models.inference.completion(
         model_id=text_model_id,
         content=user_input,
         stream=False,
@@ -191,12 +195,12 @@
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+def test_text_chat_completion_non_streaming(client_with_text_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
 
-    response = client_with_models.inference.chat_completion(
+    response = client_with_text_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {
@@ -218,12 +222,12 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t
         "inference:chat_completion:streaming_02",
     ],
 )
-def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case):
+def test_text_chat_completion_streaming(client_with_text_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
 
-    response = client_with_models.inference.chat_completion(
+    response = client_with_text_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[{"role": "user", "content": question}],
         stream=True,
@@ -239,10 +243,10 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_text_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
-    response = client_with_models.inference.chat_completion(
+    response = client_with_text_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
@@ -276,10 +280,10 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_text_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
-    response = client_with_models.inference.chat_completion(
+    response = client_with_text_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
@@ -298,10 +302,10 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
+def test_text_chat_completion_with_tool_choice_required(client_with_text_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
-    response = client_with_models.inference.chat_completion(
+    response = client_with_text_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
@@ -322,10 +326,10 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
+def test_text_chat_completion_with_tool_choice_none(client_with_text_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
-    response = client_with_models.inference.chat_completion(
+    response = client_with_text_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
@@ -342,7 +346,7 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod
         "inference:chat_completion:structured_output",
     ],
 )
-def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
+def test_text_chat_completion_structured_output(client_with_text_models, text_model_id, test_case):
     class NBAStats(BaseModel):
         year_for_draft: int
         num_seasons_in_nba: int
@@ -355,7 +359,7 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i
 
     tc = TestCase(test_case)
 
-    response = client_with_models.inference.chat_completion(
+    response = client_with_text_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         response_format={
@@ -381,7 +385,7 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i
     ],
 )
 def test_text_chat_completion_tool_calling_tools_not_in_request(
-    client_with_models, text_model_id, test_case, streaming
+    client_with_text_models, text_model_id, test_case, streaming
 ):
     tc = TestCase(test_case)
 
@@ -396,7 +400,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
         "stream": streaming,
     }
 
-    response = client_with_models.inference.chat_completion(**request)
+    response = client_with_text_models.inference.chat_completion(**request)
 
     if streaming:
         for chunk in response:
diff --git a/tests/client-sdk/inference/test_vision_inference.py b/tests/client-sdk/inference/test_vision_inference.py
index 8fa0d8023..274df03f2 100644
--- a/tests/client-sdk/inference/test_vision_inference.py
+++ b/tests/client-sdk/inference/test_vision_inference.py
@@ -27,7 +27,7 @@ def base64_image_url(base64_image_data, image_path):
     return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
 
 
-def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
+def test_image_chat_completion_non_streaming(client_with_vision_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -45,7 +45,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
             },
         ],
     }
-    response = client_with_models.inference.chat_completion(
+    response = client_with_vision_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,
@@ -55,7 +55,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
 
 
-def test_image_chat_completion_streaming(client_with_models, vision_model_id):
+def test_image_chat_completion_streaming(client_with_vision_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -73,7 +73,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
             },
         ],
     }
-    response = client_with_models.inference.chat_completion(
+    response = client_with_vision_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=True,
@@ -86,7 +86,9 @@
 
 
 @pytest.mark.parametrize("type_", ["url", "data"])
-def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
+def test_image_chat_completion_base64(
+    client_with_vision_models, vision_model_id, base64_image_data, base64_image_url, type_
+):
     image_spec = {
         "url": {
             "type": "image",
@@ -114,7 +116,7 @@ def test_image_chat_completion_base64(client_with_models, vision_model_id, base6
             },
         ],
     }
-    response = client_with_models.inference.chat_completion(
+    response = client_with_vision_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,
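
Usage note (illustrative, not part of the patch): with client_with_models split into
client_with_text_models, client_with_vision_models, and client_with_embedding_models,
a test now requests only the fixture for the modality it exercises. pytest resolves
the session-scoped fixture once, registering the model with the first inference
provider if it is not already known, and every test then shares the same client.
A minimal sketch of a consuming test follows; the test name and prompt are
hypothetical, while the fixture names and the inference.chat_completion call mirror
the diff above.

    # Hypothetical sketch; not part of the diff. The fixtures come from
    # tests/client-sdk/conftest.py as modified above.
    def test_chat_completion_smoke(client_with_text_models, text_model_id):
        # The fixture has already registered text_model_id if it was missing.
        response = client_with_text_models.inference.chat_completion(
            model_id=text_model_id,
            messages=[{"role": "user", "content": "What is 2 + 2?"}],
            stream=False,
        )
        # With stream=False a single response object comes back, matching the
        # assertions used in test_text_chat_completion_non_streaming above.
        assert "4" in response.completion_message.content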