From 67450e4024fe82709bbfbb0f90734fbc55d7941f Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Wed, 15 Jan 2025 15:39:05 -0800 Subject: [PATCH 1/2] bug fixes on inference tests (#774) # What does this PR do? Fixes two issues on providers/test/inference - [ ] Addresses issue (#issue) ## Test Plan ### Before ``` ===================================================================================== FAILURES ===================================================================================== __________________________________ TestVisionModelInference.test_vision_chat_completion_streaming[llama_vision-fireworks][llama_vision] ___________________________________ providers/tests/inference/test_vision_inference.py:145: in test_vision_chat_completion_streaming content = "".join( E TypeError: sequence item 0: expected str instance, TextDelta found ------------------------------------------------------------------------------ Captured log teardown ------------------------------------------------------------------------------- ERROR asyncio:base_events.py:1858 Task was destroyed but it is pending! task: ()>> ============================================================================= short test summary info ============================================================================== FAILED providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_streaming[llama_vision-fireworks] - TypeError: sequence item 0: expected str instance, TextDelta found ============================================================== 1 failed, 2 passed, 33 deselected, 7 warnings in 3.59s ============================================================== (base) sxyi@sxyi-mbp llama_stack % ``` ### After ``` (base) sxyi@sxyi-mbp llama_stack % pytest -k "fireworks" /Users/sxyi/llama-stack/llama_stack/providers/tests/inference/test_vision_inference.py /Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =============================================================================== test session starts ================================================================================ platform darwin -- Python 3.13.0, pytest-8.3.3, pluggy-1.5.0 rootdir: /Users/sxyi/llama-stack configfile: pyproject.toml plugins: asyncio-0.24.0, html-4.1.1, metadata-3.1.1, dependency-0.6.0, anyio-4.6.2.post1 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 36 items / 33 deselected / 3 selected providers/tests/inference/test_vision_inference.py ... [100%] =================================================================== 3 passed, 33 deselected, 7 warnings in 3.75s =================================================================== (base) sxyi@sxyi-mbp llama_stack % ``` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- llama_stack/providers/tests/inference/test_embeddings.py | 3 ++- llama_stack/providers/tests/inference/test_vision_inference.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_embeddings.py b/llama_stack/providers/tests/inference/test_embeddings.py index bf09896c1..ca0276ed6 100644 --- a/llama_stack/providers/tests/inference/test_embeddings.py +++ b/llama_stack/providers/tests/inference/test_embeddings.py @@ -6,7 +6,8 @@ import pytest -from llama_stack.apis.inference import EmbeddingsResponse, ModelType +from llama_stack.apis.inference import EmbeddingsResponse +from llama_stack.apis.models import ModelType # How to run this test: # pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index 1bdee051f..df2f3cfb9 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -143,7 +143,7 @@ class TestVisionModelInference: assert len(grouped[ChatCompletionResponseEventType.complete]) == 1 content = "".join( - chunk.event.delta + chunk.event.delta.text for chunk in grouped[ChatCompletionResponseEventType.progress] ) for expected_string in expected_strings: From 3e518c049a2e2410d4500cb19c0d3e5d51f49f08 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 15 Jan 2025 15:52:26 -0800 Subject: [PATCH 2/2] [bugfix] fix inference sdk test for v1 (#775) # What does this PR do? - fixes client sdk tests ## Test Plan ``` LLAMA_STACK_BASE_URL="http://localhost:5000" pytest -v tests/client-sdk/inference/test_inference.py ``` image ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- tests/client-sdk/inference/test_inference.py | 38 +++++++++----------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index a50dba3a0..999f3f0e5 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -28,19 +28,18 @@ def provider_tool_format(inference_provider_type): @pytest.fixture(scope="session") def inference_provider_type(llama_stack_client): providers = llama_stack_client.providers.list() - if "inference" not in providers: - pytest.fail("No inference providers available") - assert len(providers["inference"]) > 0 - return providers["inference"][0].provider_type + assert len(providers.inference) > 0 + return providers.inference[0]["provider_type"] @pytest.fixture(scope="session") def text_model_id(llama_stack_client): available_models = [ model.identifier - for model in llama_stack_client.models.list() + for model in llama_stack_client.models.list().data if model.identifier.startswith("meta-llama") and "405" not in model.identifier ] + print(available_models) assert len(available_models) > 0 return available_models[0] @@ -49,7 +48,7 @@ def text_model_id(llama_stack_client): def vision_model_id(llama_stack_client): available_models = [ model.identifier - for model in llama_stack_client.models.list() + for model in llama_stack_client.models.list().data if "vision" in model.identifier.lower() ] if len(available_models) == 0: @@ -245,19 +244,13 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming( # The returned tool inovcation content will be a string so it's easy to comapare with expected value # e.g. "[get_weather, {'location': 'San Francisco, CA'}]" def extract_tool_invocation_content(response): - text_content: str = "" tool_invocation_content: str = "" for chunk in response: delta = chunk.event.delta - if delta.type == "text": - text_content += delta.text - elif delta.type == "tool_call": - if isinstance(delta.content, str): - tool_invocation_content += delta.content - else: - call = delta.content - tool_invocation_content += f"[{call.tool_name}, {call.arguments}]" - return text_content, tool_invocation_content + if delta.type == "tool_call" and delta.parse_status == "succeeded": + call = delta.content + tool_invocation_content += f"[{call.tool_name}, {call.arguments}]" + return tool_invocation_content def test_text_chat_completion_with_tool_calling_and_streaming( @@ -274,8 +267,11 @@ def test_text_chat_completion_with_tool_calling_and_streaming( tool_prompt_format=provider_tool_format, stream=True, ) - text_content, tool_invocation_content = extract_tool_invocation_content(response) - + tool_invocation_content = extract_tool_invocation_content(response) + print( + "!!!!tool_invocation_content", + tool_invocation_content, + ) assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" @@ -362,8 +358,8 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): messages=[message], stream=True, ) - streamed_content = [ - str(chunk.event.delta.text.lower().strip()) for chunk in response - ] + streamed_content = "" + for chunk in response: + streamed_content += chunk.event.delta.text.lower() assert len(streamed_content) > 0 assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})