forked from phoenix-oss/llama-stack-mirror
		
	perf: ensure ToolCall in ChatCompletionResponse is subset of ChatCompletionRequest.tools (#1041)
# What does this PR do?
**Problem**
- Using script:
https://gist.github.com/thoraxe/6163b2145ce7b1c24c6026b64cf90085
- This hits an issue on server with `code_interpreter` not found, as we
do not pass "builtin::code_interpreter" in AgentConfig's `toolgroups`.
This is a general issue where model always tries to output
`code_interpreter` in `ToolCall` even when we do not have
`code_interpreter` available for execution.
**Reproduce Deeper Problem in chat-completion**
- Use script:
https://gist.github.com/yanxi0830/163a9ad7b5db10556043fbfc7ecd7603
1. We currently always populate `code_interpreter` in `ToolCall` in
ChatCompletionResponse if the model's response begins with
`<|python_tag|>`. See
c5f5958498/models/llama3/api/chat_format.py (L200-L213)
<img width="913" alt="image"
src="https://github.com/user-attachments/assets/328d313d-0a0b-495c-8715-61cca9ccc4a6"
/>
2. This happens even if we do not pass the `code_interpreter` as a
`tools` in ChatCompletionRequest.
**This PR**
Explicitly make sure that the tools returned in
`ChatCompletionResponse.tool_calls` is always a tool requested by
`ChatCompletionRequest.tools`.
[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])
## Test Plan
**Before**
<img width="913" alt="image"
src="https://github.com/user-attachments/assets/328d313d-0a0b-495c-8715-61cca9ccc4a6"
/>
<img width="997" alt="image"
src="https://github.com/user-attachments/assets/d3e82b62-b142-4939-954c-62843bec7110"
/>
**After**
<img width="856" alt="image"
src="https://github.com/user-attachments/assets/2c70ce55-c8d0-45ea-b10f-f70adc50d3d9"
/>
<img width="1000" alt="image"
src="https://github.com/user-attachments/assets/b5e81826-c35b-4052-bf81-7afff93ce2ef"
/>
**Unit Test**
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request --inference-model "meta-llama/Llama-3.3-70B-Instruct"
```
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/
```
<img width="1002" alt="image"
src="https://github.com/user-attachments/assets/04808517-eded-4122-97f5-7e5142de9779"
/>
**Streaming**
- Chat Completion
<img width="902" alt="image"
src="https://github.com/user-attachments/assets/f477bc86-bd38-4729-b49e-a0a6ed3f835a"
/>
- Agent
<img width="916" alt="image"
src="https://github.com/user-attachments/assets/f4cc3417-23cd-46b1-953d-3a2271e79bbb"
/>
[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)
			
			
This commit is contained in:
		
							parent
							
								
									dd37e58868
								
							
						
					
					
						commit
						66d7e15c93
					
				
					 14 changed files with 164 additions and 33 deletions
				
			
		|  | @ -158,7 +158,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in | |||
|     "question,expected", | ||||
|     [ | ||||
|         ("Which planet do humans live on?", "Earth"), | ||||
|         ("Which planet has rings around it with a name starting with letter S?", "Saturn"), | ||||
|         ( | ||||
|             "Which planet has rings around it with a name starting with letter S?", | ||||
|             "Saturn", | ||||
|         ), | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected): | ||||
|  | @ -280,3 +283,82 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i | |||
|     assert answer.last_name == "Jordan" | ||||
|     assert answer.year_of_birth == 1963 | ||||
|     assert answer.num_seasons_in_nba == 15 | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "streaming", | ||||
|     [ | ||||
|         True, | ||||
|         False, | ||||
|     ], | ||||
| ) | ||||
| def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming): | ||||
|     # TODO: more dynamic lookup on tool_prompt_format for model family | ||||
|     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" | ||||
|     request = { | ||||
|         "model_id": text_model_id, | ||||
|         "messages": [ | ||||
|             {"role": "system", "content": "You are a helpful assistant."}, | ||||
|             { | ||||
|                 "role": "user", | ||||
|                 "content": "What pods are in the namespace openshift-lightspeed?", | ||||
|             }, | ||||
|             { | ||||
|                 "role": "assistant", | ||||
|                 "content": "", | ||||
|                 "stop_reason": "end_of_turn", | ||||
|                 "tool_calls": [ | ||||
|                     { | ||||
|                         "call_id": "1", | ||||
|                         "tool_name": "get_object_namespace_list", | ||||
|                         "arguments": { | ||||
|                             "kind": "pod", | ||||
|                             "namespace": "openshift-lightspeed", | ||||
|                         }, | ||||
|                     } | ||||
|                 ], | ||||
|             }, | ||||
|             { | ||||
|                 "role": "tool", | ||||
|                 "call_id": "1", | ||||
|                 "tool_name": "get_object_namespace_list", | ||||
|                 "content": "the objects are pod1, pod2, pod3", | ||||
|             }, | ||||
|         ], | ||||
|         "tools": [ | ||||
|             { | ||||
|                 "tool_name": "get_object_namespace_list", | ||||
|                 "description": "Get the list of objects in a namespace", | ||||
|                 "parameters": { | ||||
|                     "kind": { | ||||
|                         "param_type": "string", | ||||
|                         "description": "the type of object", | ||||
|                         "required": True, | ||||
|                     }, | ||||
|                     "namespace": { | ||||
|                         "param_type": "string", | ||||
|                         "description": "the name of the namespace", | ||||
|                         "required": True, | ||||
|                     }, | ||||
|                 }, | ||||
|             } | ||||
|         ], | ||||
|         "tool_choice": "auto", | ||||
|         "tool_prompt_format": tool_prompt_format, | ||||
|         "stream": streaming, | ||||
|     } | ||||
| 
 | ||||
|     response = llama_stack_client.inference.chat_completion(**request) | ||||
| 
 | ||||
|     if streaming: | ||||
|         for chunk in response: | ||||
|             delta = chunk.event.delta | ||||
|             if delta.type == "tool_call" and delta.parse_status == "succeeded": | ||||
|                 assert delta.tool_call.tool_name == "get_object_namespace_list" | ||||
|             if delta.type == "tool_call" and delta.parse_status == "failed": | ||||
|                 # expect raw message that failed to parse in tool_call | ||||
|                 assert type(delta.tool_call) == str | ||||
|                 assert len(delta.tool_call) > 0 | ||||
|     else: | ||||
|         for tc in response.completion_message.tool_calls: | ||||
|             assert tc.tool_name == "get_object_namespace_list" | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue