From 3e518c049a2e2410d4500cb19c0d3e5d51f49f08 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 15 Jan 2025 15:52:26 -0800 Subject: [PATCH] [bugfix] fix inference sdk test for v1 (#775) # What does this PR do? - fixes client sdk tests ## Test Plan ``` LLAMA_STACK_BASE_URL="http://localhost:5000" pytest -v tests/client-sdk/inference/test_inference.py ``` image ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- tests/client-sdk/inference/test_inference.py | 38 +++++++++----------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index a50dba3a0..999f3f0e5 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -28,19 +28,18 @@ def provider_tool_format(inference_provider_type): @pytest.fixture(scope="session") def inference_provider_type(llama_stack_client): providers = llama_stack_client.providers.list() - if "inference" not in providers: - pytest.fail("No inference providers available") - assert len(providers["inference"]) > 0 - return providers["inference"][0].provider_type + assert len(providers.inference) > 0 + return providers.inference[0]["provider_type"] @pytest.fixture(scope="session") def text_model_id(llama_stack_client): available_models = [ model.identifier - for model in llama_stack_client.models.list() + for model in llama_stack_client.models.list().data if model.identifier.startswith("meta-llama") and "405" not in model.identifier ] + print(available_models) assert len(available_models) > 0 return available_models[0] @@ -49,7 +48,7 @@ def text_model_id(llama_stack_client): def vision_model_id(llama_stack_client): available_models = [ model.identifier - for model in llama_stack_client.models.list() + for model in llama_stack_client.models.list().data if "vision" in model.identifier.lower() ] if len(available_models) == 0: @@ -245,19 +244,13 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming( # The returned tool inovcation content will be a string so it's easy to comapare with expected value # e.g. "[get_weather, {'location': 'San Francisco, CA'}]" def extract_tool_invocation_content(response): - text_content: str = "" tool_invocation_content: str = "" for chunk in response: delta = chunk.event.delta - if delta.type == "text": - text_content += delta.text - elif delta.type == "tool_call": - if isinstance(delta.content, str): - tool_invocation_content += delta.content - else: - call = delta.content - tool_invocation_content += f"[{call.tool_name}, {call.arguments}]" - return text_content, tool_invocation_content + if delta.type == "tool_call" and delta.parse_status == "succeeded": + call = delta.content + tool_invocation_content += f"[{call.tool_name}, {call.arguments}]" + return tool_invocation_content def test_text_chat_completion_with_tool_calling_and_streaming( @@ -274,8 +267,11 @@ def test_text_chat_completion_with_tool_calling_and_streaming( tool_prompt_format=provider_tool_format, stream=True, ) - text_content, tool_invocation_content = extract_tool_invocation_content(response) - + tool_invocation_content = extract_tool_invocation_content(response) + print( + "!!!!tool_invocation_content", + tool_invocation_content, + ) assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" @@ -362,8 +358,8 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): messages=[message], stream=True, ) - streamed_content = [ - str(chunk.event.delta.text.lower().strip()) for chunk in response - ] + streamed_content = "" + for chunk in response: + streamed_content += chunk.event.delta.text.lower() assert len(streamed_content) > 0 assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})