From d7cc38e93424b9d4610b139889c9d8e8d4ee2352 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Fri, 18 Jul 2025 00:35:28 -0400 Subject: [PATCH] fix: remove async test markers (fix pre-commit) (#2808) # What does this PR do? some async test markers are in the codebase causing pre-commit to fail due to #2744 remove these pytest fixtures ## Test Plan pre-commit passes Signed-off-by: Charlie Doern --- tests/integration/post_training/test_post_training.py | 6 +++--- tests/unit/models/test_prompt_adapter.py | 10 ---------- tests/unit/providers/vector_io/remote/test_milvus.py | 10 ++-------- tests/unit/rag/test_rag_query.py | 1 - 4 files changed, 5 insertions(+), 22 deletions(-) diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index bb4639d17..3d56b322f 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -123,14 +123,14 @@ class TestPostTraining: logger.info(f"Job artifacts: {artifacts}") # TODO: Fix these tests to properly represent the Jobs API in training - # @pytest.mark.asyncio + # # async def test_get_training_jobs(self, post_training_stack): # post_training_impl = post_training_stack # jobs_list = await post_training_impl.get_training_jobs() # assert isinstance(jobs_list, list) # assert jobs_list[0].job_uuid == "1234" - # @pytest.mark.asyncio + # # async def test_get_training_job_status(self, post_training_stack): # post_training_impl = post_training_stack # job_status = await post_training_impl.get_training_job_status("1234") @@ -139,7 +139,7 @@ class TestPostTraining: # assert job_status.status == JobStatus.completed # assert isinstance(job_status.checkpoints[0], Checkpoint) - # @pytest.mark.asyncio + # # async def test_get_training_job_artifacts(self, post_training_stack): # post_training_impl = post_training_stack # job_artifacts = await post_training_impl.get_training_job_artifacts("1234") diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index 577496cec..0362eb5dd 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -32,7 +31,6 @@ MODEL = "Llama3.1-8B-Instruct" MODEL3_2 = "Llama3.2-3B-Instruct" -@pytest.mark.asyncio async def test_system_default(): content = "Hello !" request = ChatCompletionRequest( @@ -47,7 +45,6 @@ async def test_system_default(): assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) -@pytest.mark.asyncio async def test_system_builtin_only(): content = "Hello !" request = ChatCompletionRequest( @@ -67,7 +64,6 @@ async def test_system_builtin_only(): assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) -@pytest.mark.asyncio async def test_system_custom_only(): content = "Hello !" request = ChatCompletionRequest( @@ -98,7 +94,6 @@ async def test_system_custom_only(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_system_custom_and_builtin(): content = "Hello !" request = ChatCompletionRequest( @@ -132,7 +127,6 @@ async def test_system_custom_and_builtin(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_completion_message_encoding(): request = ChatCompletionRequest( model=MODEL3_2, @@ -174,7 +168,6 @@ async def test_completion_message_encoding(): assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt -@pytest.mark.asyncio async def test_user_provided_system_message(): content = "Hello !" system_prompt = "You are a pirate" @@ -195,7 +188,6 @@ async def test_user_provided_system_message(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_replace_system_message_behavior_builtin_tools(): content = "Hello !" system_prompt = "You are a pirate" @@ -221,7 +213,6 @@ async def test_replace_system_message_behavior_builtin_tools(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_replace_system_message_behavior_custom_tools(): content = "Hello !" system_prompt = "You are a pirate" @@ -259,7 +250,6 @@ async def test_replace_system_message_behavior_custom_tools(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_replace_system_message_behavior_custom_tools_with_template(): content = "Hello !" system_prompt = "You are a pirate {{ function_description }}" diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 2f212e374..145edf7fb 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest -import pytest_asyncio from llama_stack.apis.vector_io import QueryChunksResponse @@ -33,7 +32,7 @@ with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): MILVUS_PROVIDER = "milvus" -@pytest_asyncio.fixture +@pytest.fixture async def mock_milvus_client() -> MagicMock: """Create a mock Milvus client with common method behaviors.""" client = MagicMock() @@ -84,7 +83,7 @@ async def mock_milvus_client() -> MagicMock: return client -@pytest_asyncio.fixture +@pytest.fixture async def milvus_index(mock_milvus_client): """Create a MilvusIndex with mocked client.""" index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection") @@ -92,7 +91,6 @@ async def milvus_index(mock_milvus_client): # No real cleanup needed since we're using mocks -@pytest.mark.asyncio async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): # Setup: collection doesn't exist initially, then exists after creation mock_milvus_client.has_collection.side_effect = [False, True] @@ -108,7 +106,6 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m assert len(insert_call[1]["data"]) == len(sample_chunks) -@pytest.mark.asyncio async def test_query_chunks_vector( milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client ): @@ -125,7 +122,6 @@ async def test_query_chunks_vector( mock_milvus_client.search.assert_called_once() -@pytest.mark.asyncio async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): mock_milvus_client.has_collection.return_value = True await milvus_index.add_chunks(sample_chunks, sample_embeddings) @@ -138,7 +134,6 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e assert len(response.chunks) == 2 -@pytest.mark.asyncio async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): """Test that when BM25 search fails, the system falls back to simple text search.""" mock_milvus_client.has_collection.return_value = True @@ -181,7 +176,6 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring" -@pytest.mark.asyncio async def test_delete_collection(milvus_index, mock_milvus_client): # Test collection deletion mock_milvus_client.has_collection.return_value = True diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index ad155c205..a9149541a 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -64,7 +64,6 @@ class TestRagQuery: with pytest.raises(ValueError): RAGQueryConfig(mode="invalid_mode") - @pytest.mark.asyncio async def test_query_accepts_valid_modes(self): RAGQueryConfig() # Test default (vector) RAGQueryConfig(mode="vector") # Test vector