forked from phoenix-oss/llama-stack-mirror
		
	feat: completing text /chat-completion and /completion tests (#1223)
# What does this PR do?

The goal is to have a fairly complete set of provider and e2e tests for /chat-completion and /completion. This is the current list:

```
grep -oE "def test_[a-zA-Z_+]*" llama_stack/providers/tests/inference/test_text_inference.py | cut -d' ' -f2
```

- test_model_list
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_logprobs_non_streaming
- test_text_completion_logprobs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_structured_output
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling
- test_text_chat_completion_with_tool_calling_streaming

```
grep -oE "def test_[a-zA-Z_+]*" tests/client-sdk/inference/test_text_inference.py | cut -d' ' -f2
```

- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_log_probs_non_streaming
- test_text_completion_log_probs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling_and_non_streaming
- test_text_chat_completion_with_tool_calling_and_streaming
- test_text_chat_completion_with_tool_choice_required
- test_text_chat_completion_with_tool_choice_none
- test_text_chat_completion_structured_output
- test_text_chat_completion_tool_calling_tools_not_in_request

## Test plan

== Set up Ollama local server

```
OLLAMA_HOST=127.0.0.1:8321 with-proxy ollama serve
OLLAMA_HOST=127.0.0.1:8321 ollama run llama3.2:3b-instruct-fp16 --keepalive 60m
```

== Run a provider test

```
conda activate stack
OLLAMA_URL="http://localhost:8321" \
pytest -v -s -k "ollama" --inference-model="llama3.2:3b-instruct-fp16" \
llama_stack/providers/tests/inference/test_text_inference.py::TestInference
```

== Run an e2e test

```
conda activate sherpa
with-proxy pip install llama-stack
export INFERENCE_MODEL=llama3.2:3b-instruct-fp16
export LLAMA_STACK_PORT=8322
with-proxy llama stack build --template ollama
with-proxy llama stack run --env OLLAMA_URL=http://localhost:8321 ollama
```

```
conda activate stack
LLAMA_STACK_PORT=8322 LLAMA_STACK_BASE_URL="http://localhost:8322" \
pytest -v -s --inference-model="llama3.2:3b-instruct-fp16" \
tests/client-sdk/inference/test_text_inference.py
```
This commit is contained in:
		
							parent
							
								
									9b130f96a7
								
							
						
					
					
						commit
						3a31611486
					
				
					 8 changed files with 479 additions and 223 deletions
				
			
		|  | @ -28,23 +28,17 @@ def provider_tool_format(inference_provider_type): | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def get_weather_tool_definition(): | ||||
|     return { | ||||
|         "tool_name": "get_weather", | ||||
|         "description": "Get the current weather", | ||||
|         "parameters": { | ||||
|             "location": { | ||||
|                 "param_type": "string", | ||||
|                 "description": "The city and state, e.g. San Francisco, CA", | ||||
|             }, | ||||
|         }, | ||||
|     } | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:completion:sanity", | ||||
|     ], | ||||
| ) | ||||
| def test_text_completion_non_streaming(client_with_models, text_model_id, test_case): | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
| 
 | ||||
| def test_text_completion_non_streaming(client_with_models, text_model_id): | ||||
|     response = client_with_models.inference.completion( | ||||
|         content="Complete the sentence using one word: Roses are red, violets are ", | ||||
|         content=tc["content"], | ||||
|         stream=False, | ||||
|         model_id=text_model_id, | ||||
|         sampling_params={ | ||||
|  | @ -55,9 +49,17 @@ def test_text_completion_non_streaming(client_with_models, text_model_id): | |||
|     # assert "blue" in response.content.lower().strip() | ||||
| 
 | ||||
| 
 | ||||
| def test_text_completion_streaming(client_with_models, text_model_id): | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:completion:sanity", | ||||
|     ], | ||||
| ) | ||||
| def test_text_completion_streaming(client_with_models, text_model_id, test_case): | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     response = client_with_models.inference.completion( | ||||
|         content="Complete the sentence using one word: Roses are red, violets are ", | ||||
|         content=tc["content"], | ||||
|         stream=True, | ||||
|         model_id=text_model_id, | ||||
|         sampling_params={ | ||||
|  | @ -70,12 +72,20 @@ def test_text_completion_streaming(client_with_models, text_model_id): | |||
|     assert len(content_str) > 10 | ||||
| 
 | ||||
| 
 | ||||
| def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type): | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:completion:log_probs", | ||||
|     ], | ||||
| ) | ||||
| def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case): | ||||
|     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: | ||||
|         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") | ||||
| 
 | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     response = client_with_models.inference.completion( | ||||
|         content="Complete the sentence: Micheael Jordan is born in ", | ||||
|         content=tc["content"], | ||||
|         stream=False, | ||||
|         model_id=text_model_id, | ||||
|         sampling_params={ | ||||
|  | @ -90,12 +100,20 @@ def test_completion_log_probs_non_streaming(client_with_models, text_model_id, i | |||
|     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs) | ||||
| 
 | ||||
| 
 | ||||
| def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type): | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:completion:log_probs", | ||||
|     ], | ||||
| ) | ||||
| def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case): | ||||
|     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: | ||||
|         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") | ||||
| 
 | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     response = client_with_models.inference.completion( | ||||
|         content="Complete the sentence: Micheael Jordan is born in ", | ||||
|         content=tc["content"], | ||||
|         stream=True, | ||||
|         model_id=text_model_id, | ||||
|         sampling_params={ | ||||
|  | @ -114,7 +132,12 @@ def test_completion_log_probs_streaming(client_with_models, text_model_id, infer | |||
|             assert not chunk.logprobs, "Logprobs should be empty" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("test_case", ["completion-01"]) | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:completion:structured_output", | ||||
|     ], | ||||
| ) | ||||
| def test_text_completion_structured_output(client_with_models, text_model_id, test_case): | ||||
|     class AnswerFormat(BaseModel): | ||||
|         name: str | ||||
|  | @ -144,16 +167,17 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te | |||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "question,expected", | ||||
|     "test_case", | ||||
|     [ | ||||
|         ("Which planet do humans live on?", "Earth"), | ||||
|         ( | ||||
|             "Which planet has rings around it with a name starting with letter S?", | ||||
|             "Saturn", | ||||
|         ), | ||||
|         "inference:chat_completion:non_streaming_01", | ||||
|         "inference:chat_completion:non_streaming_02", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected): | ||||
| def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case): | ||||
|     tc = TestCase(test_case) | ||||
|     question = tc["question"] | ||||
|     expected = tc["expected"] | ||||
| 
 | ||||
|     response = client_with_models.inference.chat_completion( | ||||
|         model_id=text_model_id, | ||||
|         messages=[ | ||||
|  | @ -170,13 +194,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, q | |||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "question,expected", | ||||
|     "test_case", | ||||
|     [ | ||||
|         ("What's the name of the Sun in latin?", "Sol"), | ||||
|         ("What is the name of the US captial?", "Washington"), | ||||
|         "inference:chat_completion:streaming_01", | ||||
|         "inference:chat_completion:streaming_02", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected): | ||||
| def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case): | ||||
|     tc = TestCase(test_case) | ||||
|     question = tc["question"] | ||||
|     expected = tc["expected"] | ||||
| 
 | ||||
|     response = client_with_models.inference.chat_completion( | ||||
|         model_id=text_model_id, | ||||
|         messages=[{"role": "user", "content": question}], | ||||
|  | @ -187,18 +215,26 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, quest | |||
|     assert expected.lower() in "".join(streamed_content) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:chat_completion:tool_calling", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_with_tool_calling_and_non_streaming( | ||||
|     client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format | ||||
|     client_with_models, text_model_id, provider_tool_format, test_case | ||||
| ): | ||||
|     # TODO: more dynamic lookup on tool_prompt_format for model family | ||||
|     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" | ||||
| 
 | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     response = client_with_models.inference.chat_completion( | ||||
|         model_id=text_model_id, | ||||
|         messages=[ | ||||
|             {"role": "system", "content": "You are a helpful assistant."}, | ||||
|             {"role": "user", "content": "What's the weather like in San Francisco?"}, | ||||
|         ], | ||||
|         tools=[get_weather_tool_definition], | ||||
|         messages=tc["messages"], | ||||
|         tools=tc["tools"], | ||||
|         tool_choice="auto", | ||||
|         tool_prompt_format=provider_tool_format, | ||||
|         tool_prompt_format=tool_prompt_format, | ||||
|         stream=False, | ||||
|     ) | ||||
|     # No content is returned for the system message since we expect the | ||||
|  | @ -207,8 +243,8 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming( | |||
|     assert response.completion_message.role == "assistant" | ||||
| 
 | ||||
|     assert len(response.completion_message.tool_calls) == 1 | ||||
|     assert response.completion_message.tool_calls[0].tool_name == "get_weather" | ||||
|     assert response.completion_message.tool_calls[0].arguments == {"location": "San Francisco, CA"} | ||||
|     assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"] | ||||
|     assert response.completion_message.tool_calls[0].arguments == tc["expected"] | ||||
| 
 | ||||
| 
 | ||||
| # Will extract streamed text and separate it from tool invocation content | ||||
|  | @ -224,57 +260,80 @@ def extract_tool_invocation_content(response): | |||
|     return tool_invocation_content | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:chat_completion:tool_calling", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_with_tool_calling_and_streaming( | ||||
|     client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format | ||||
|     client_with_models, text_model_id, provider_tool_format, test_case | ||||
| ): | ||||
|     # TODO: more dynamic lookup on tool_prompt_format for model family | ||||
|     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" | ||||
| 
 | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     response = client_with_models.inference.chat_completion( | ||||
|         model_id=text_model_id, | ||||
|         messages=[ | ||||
|             {"role": "system", "content": "You are a helpful assistant."}, | ||||
|             {"role": "user", "content": "What's the weather like in San Francisco?"}, | ||||
|         ], | ||||
|         tools=[get_weather_tool_definition], | ||||
|         messages=tc["messages"], | ||||
|         tools=tc["tools"], | ||||
|         tool_choice="auto", | ||||
|         tool_prompt_format=provider_tool_format, | ||||
|         tool_prompt_format=tool_prompt_format, | ||||
|         stream=True, | ||||
|     ) | ||||
|     tool_invocation_content = extract_tool_invocation_content(response) | ||||
|     assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" | ||||
|     expected_tool_name = tc["tools"][0]["tool_name"] | ||||
|     expected_argument = tc["expected"] | ||||
|     assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:chat_completion:tool_calling", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_with_tool_choice_required( | ||||
|     client_with_models, | ||||
|     text_model_id, | ||||
|     get_weather_tool_definition, | ||||
|     provider_tool_format, | ||||
|     test_case, | ||||
| ): | ||||
|     # TODO: more dynamic lookup on tool_prompt_format for model family | ||||
|     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" | ||||
| 
 | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     response = client_with_models.inference.chat_completion( | ||||
|         model_id=text_model_id, | ||||
|         messages=[ | ||||
|             {"role": "system", "content": "You are a helpful assistant."}, | ||||
|             {"role": "user", "content": "What's the weather like in San Francisco?"}, | ||||
|         ], | ||||
|         tools=[get_weather_tool_definition], | ||||
|         messages=tc["messages"], | ||||
|         tools=tc["tools"], | ||||
|         tool_config={ | ||||
|             "tool_choice": "required", | ||||
|             "tool_prompt_format": provider_tool_format, | ||||
|             "tool_prompt_format": tool_prompt_format, | ||||
|         }, | ||||
|         stream=True, | ||||
|     ) | ||||
|     tool_invocation_content = extract_tool_invocation_content(response) | ||||
|     assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" | ||||
|     expected_tool_name = tc["tools"][0]["tool_name"] | ||||
|     expected_argument = tc["expected"] | ||||
|     assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]" | ||||
| 
 | ||||
| 
 | ||||
| def test_text_chat_completion_with_tool_choice_none( | ||||
|     client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format | ||||
| ): | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:chat_completion:tool_calling", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case): | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     response = client_with_models.inference.chat_completion( | ||||
|         model_id=text_model_id, | ||||
|         messages=[ | ||||
|             {"role": "system", "content": "You are a helpful assistant."}, | ||||
|             {"role": "user", "content": "What's the weather like in San Francisco?"}, | ||||
|         ], | ||||
|         tools=[get_weather_tool_definition], | ||||
|         messages=tc["messages"], | ||||
|         tools=tc["tools"], | ||||
|         tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format}, | ||||
|         stream=True, | ||||
|     ) | ||||
|  | @ -282,7 +341,12 @@ def test_text_chat_completion_with_tool_choice_none( | |||
|     assert tool_invocation_content == "" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("test_case", ["chat_completion-01"]) | ||||
| @pytest.mark.parametrize( | ||||
|     "test_case", | ||||
|     [ | ||||
|         "inference:chat_completion:structured_output", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case): | ||||
|     class AnswerFormat(BaseModel): | ||||
|         first_name: str | ||||
|  | @ -309,64 +373,24 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i | |||
|     assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"] | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("streaming", [True, False]) | ||||
| @pytest.mark.parametrize( | ||||
|     "streaming", | ||||
|     "test_case", | ||||
|     [ | ||||
|         True, | ||||
|         False, | ||||
|         "inference:chat_completion:tool_calling_tools_absent", | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming): | ||||
| def test_text_chat_completion_tool_calling_tools_not_in_request( | ||||
|     client_with_models, text_model_id, test_case, streaming | ||||
| ): | ||||
|     tc = TestCase(test_case) | ||||
| 
 | ||||
|     # TODO: more dynamic lookup on tool_prompt_format for model family | ||||
|     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" | ||||
|     request = { | ||||
|         "model_id": text_model_id, | ||||
|         "messages": [ | ||||
|             {"role": "system", "content": "You are a helpful assistant."}, | ||||
|             { | ||||
|                 "role": "user", | ||||
|                 "content": "What pods are in the namespace openshift-lightspeed?", | ||||
|             }, | ||||
|             { | ||||
|                 "role": "assistant", | ||||
|                 "content": "", | ||||
|                 "stop_reason": "end_of_turn", | ||||
|                 "tool_calls": [ | ||||
|                     { | ||||
|                         "call_id": "1", | ||||
|                         "tool_name": "get_object_namespace_list", | ||||
|                         "arguments": { | ||||
|                             "kind": "pod", | ||||
|                             "namespace": "openshift-lightspeed", | ||||
|                         }, | ||||
|                     } | ||||
|                 ], | ||||
|             }, | ||||
|             { | ||||
|                 "role": "tool", | ||||
|                 "call_id": "1", | ||||
|                 "tool_name": "get_object_namespace_list", | ||||
|                 "content": "the objects are pod1, pod2, pod3", | ||||
|             }, | ||||
|         ], | ||||
|         "tools": [ | ||||
|             { | ||||
|                 "tool_name": "get_object_namespace_list", | ||||
|                 "description": "Get the list of objects in a namespace", | ||||
|                 "parameters": { | ||||
|                     "kind": { | ||||
|                         "param_type": "string", | ||||
|                         "description": "the type of object", | ||||
|                         "required": True, | ||||
|                     }, | ||||
|                     "namespace": { | ||||
|                         "param_type": "string", | ||||
|                         "description": "the name of the namespace", | ||||
|                         "required": True, | ||||
|                     }, | ||||
|                 }, | ||||
|             } | ||||
|         ], | ||||
|         "messages": tc["messages"], | ||||
|         "tools": tc["tools"], | ||||
|         "tool_choice": "auto", | ||||
|         "tool_prompt_format": tool_prompt_format, | ||||
|         "stream": streaming, | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue