perf: ensure ToolCall in ChatCompletionResponse is subset of ChatCompletionRequest.tools (#1041)
# What does this PR do?
**Problem**
- Using script:
https://gist.github.com/thoraxe/6163b2145ce7b1c24c6026b64cf90085
- This hits an issue on the server with `code_interpreter` not being found, because we do not pass `builtin::code_interpreter` in AgentConfig's `toolgroups`. This is a general issue: the model always tries to output a `code_interpreter` `ToolCall` even when `code_interpreter` is not available for execution.
**Reproducing the deeper problem in chat-completion**
- Use script:
https://gist.github.com/yanxi0830/163a9ad7b5db10556043fbfc7ecd7603
1. We currently always populate `code_interpreter` in `ToolCall` in
ChatCompletionResponse if the model's response begins with
`<|python_tag|>`. See
c5f5958498/models/llama3/api/chat_format.py (L200-L213)
<img width="913" alt="image"
src="https://github.com/user-attachments/assets/328d313d-0a0b-495c-8715-61cca9ccc4a6"
/>
2. This happens even if we do not pass `code_interpreter` as a tool in `ChatCompletionRequest.tools`. (A simplified sketch of this decoding behavior follows the list.)
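For illustration, here is a minimal, hypothetical sketch of the decoding behavior described in point 1 (the real logic lives in `chat_format.py`, linked above; the `ToolCall` shape and `decode_assistant_message` function below are simplified stand-ins, not the actual implementation):

```python
from dataclasses import dataclass, field


@dataclass
class ToolCall:
    call_id: str
    tool_name: str
    arguments: dict = field(default_factory=dict)


def decode_assistant_message(content: str) -> list[ToolCall]:
    # If the raw completion begins with <|python_tag|>, the decoder emits a
    # code_interpreter ToolCall -- regardless of which tools the request declared.
    if content.startswith("<|python_tag|>"):
        code = content[len("<|python_tag|>"):]
        return [ToolCall(call_id="0", tool_name="code_interpreter", arguments={"code": code})]
    return []


# Even with no tools passed in the request, the decoded message claims a tool call:
print(decode_assistant_message("<|python_tag|>import os\nprint(os.listdir())"))
```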
**This PR**
Explicitly ensure that every tool call returned in `ChatCompletionResponse.tool_calls` refers to a tool requested via `ChatCompletionRequest.tools`; tool calls for tools that were not requested are dropped, and the raw model text is returned as content instead.
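A minimal sketch of that rule, using plain dicts instead of the actual `ToolCall`/tool-definition types from `llama_stack.apis.inference` (the `filter_tool_calls` helper is illustrative only, not the code in this PR):

```python
def filter_tool_calls(tool_calls: list[dict], request_tools: list[dict], raw_text: str):
    """Keep only tool calls whose tool_name appears in the request's tools.

    If any call is dropped (or no tools were requested at all), fall back to
    returning the raw model text as plain content.
    """
    allowed = {t["tool_name"] for t in request_tools}
    kept = [tc for tc in tool_calls if tc["tool_name"] in allowed]
    content = raw_text if len(kept) < len(tool_calls) else ""
    return kept, content


decoded = [{"tool_name": "code_interpreter", "arguments": {"code": "1 + 1"}}]
print(filter_tool_calls(decoded, request_tools=[], raw_text="<|python_tag|>1 + 1"))
# -> ([], '<|python_tag|>1 + 1')
```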
## Test Plan
**Before**
<img width="913" alt="image"
src="https://github.com/user-attachments/assets/328d313d-0a0b-495c-8715-61cca9ccc4a6"
/>
<img width="997" alt="image"
src="https://github.com/user-attachments/assets/d3e82b62-b142-4939-954c-62843bec7110"
/>
**After**
<img width="856" alt="image"
src="https://github.com/user-attachments/assets/2c70ce55-c8d0-45ea-b10f-f70adc50d3d9"
/>
<img width="1000" alt="image"
src="https://github.com/user-attachments/assets/b5e81826-c35b-4052-bf81-7afff93ce2ef"
/>
**Unit Test**
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request --inference-model "meta-llama/Llama-3.3-70B-Instruct"
```
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/
```
<img width="1002" alt="image"
src="https://github.com/user-attachments/assets/04808517-eded-4122-97f5-7e5142de9779"
/>
**Streaming**
- Chat Completion
<img width="902" alt="image"
src="https://github.com/user-attachments/assets/f477bc86-bd38-4729-b49e-a0a6ed3f835a"
/>
- Agent
<img width="916" alt="image"
src="https://github.com/user-attachments/assets/f4cc3417-23cd-46b1-953d-3a2271e79bbb"
/>
Parent: dd37e58868 · Commit: 66d7e15c93 · 14 changed files with 164 additions and 33 deletions
```diff
@@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin):
                 if delta.type == "tool_call":
                     if delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_calls.append(delta.tool_call)
+                    elif delta.parse_status == ToolCallParseStatus.failed:
+                        # If we cannot parse the tools, set the content to the unparsed raw text
+                        content = delta.tool_call
                 if stream:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
```
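For context, a self-contained sketch of how a stream consumer can handle the new `failed` parse status the same way the agent hunk above does. The `ToolCallDelta` class and `fake_stream` generator are hypothetical stand-ins for the llama-stack stream types, not the real API:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    type: str
    tool_call: object   # a ToolCall object on success, the raw string on failure
    parse_status: str   # "succeeded" | "failed"


async def fake_stream():
    # Simulates a model emission that could not be matched to a requested tool.
    yield ToolCallDelta(type="tool_call", tool_call="<|python_tag|>1 + 1", parse_status="failed")


async def collect(stream):
    tool_calls, content = [], ""
    async for delta in stream:
        if delta.type == "tool_call":
            if delta.parse_status == "succeeded":
                tool_calls.append(delta.tool_call)
            elif delta.parse_status == "failed":
                # The raw, unparsed model text is surfaced as plain content.
                content = delta.tool_call
    return tool_calls, content


print(asyncio.run(collect(fake_stream())))  # ([], '<|python_tag|>1 + 1')
```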
```diff
@@ -201,7 +201,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
            )

        stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
            yield chunk

     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
```
```diff
@@ -134,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )

         response = OpenAICompatCompletionResponse(choices=[choice])
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_chat_completion(request)
@@ -152,7 +152,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             yield OpenAICompatCompletionResponse(choices=[choice])

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
```
```diff
@@ -155,14 +155,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         r = await self.client.completions.create(**params)

-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)

     async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         stream = await self.client.completions.create(**params)

-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
```
```diff
@@ -112,7 +112,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -123,7 +123,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
```
```diff
@@ -230,7 +230,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -244,7 +244,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
```
```diff
@@ -304,7 +304,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -330,7 +330,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
            )

        stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
            yield chunk

     async def embeddings(
```
```diff
@@ -99,7 +99,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -110,7 +110,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
```
```diff
@@ -160,7 +160,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def embeddings(
```
```diff
@@ -236,7 +236,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -252,7 +252,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
            )

        stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
            yield chunk

     async def _get_params(self, request: ChatCompletionRequest) -> dict:
```
```diff
@@ -220,7 +220,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             r = self._get_client().chat.completions.create(**params)
         else:
             r = self._get_client().completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -235,7 +235,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
```
```diff
@@ -232,7 +232,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
```
```diff
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union

 from llama_models.datatypes import (
@@ -26,6 +26,7 @@ from llama_stack.apis.common.content_types import (
 )

 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -41,6 +42,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )

+logger = logging.getLogger(__name__)
+

 class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
```
```diff
@@ -170,7 +173,9 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format


 def process_chat_completion_response(
-    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+    response: OpenAICompatCompletionResponse,
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
     choice = response.choices[0]

@@ -179,6 +184,28 @@ def process_chat_completion_response(
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )

+    # NOTE: If we do not set tools in chat-completion request, we should not
+    # expect the ToolCall in the response. Instead, we should return the raw
+    # response from the model.
+    if raw_message.tool_calls:
+        if not request.tools:
+            raw_message.tool_calls = []
+            raw_message.content = text_from_choice(choice)
+        else:
+            # only return tool_calls if provided in the request
+            new_tool_calls = []
+            request_tools = {t.tool_name: t for t in request.tools}
+            for t in raw_message.tool_calls:
+                if t.tool_name in request_tools:
+                    new_tool_calls.append(t)
+                else:
+                    logger.warning(f"Tool {t.tool_name} not found in request tools")
+
+            if len(new_tool_calls) < len(raw_message.tool_calls):
+                raw_message.tool_calls = new_tool_calls
+                raw_message.content = text_from_choice(choice)
+
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
             content=raw_message.content,
```
```diff
@@ -226,7 +253,9 @@ async def process_completion_stream_response(


 async def process_chat_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
@@ -305,6 +334,7 @@ async def process_chat_completion_stream_response(

     # parse tool calls and report errors
     message = formatter.decode_assistant_message_from_content(buffer, stop_reason)

+    parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
         yield ChatCompletionResponseStreamChunk(
@@ -318,17 +348,33 @@ async def process_chat_completion_stream_response(
             )
         )

+    request_tools = {t.tool_name: t for t in request.tools}
     for tool_call in message.tool_calls:
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.progress,
-                delta=ToolCallDelta(
-                    tool_call=tool_call,
-                    parse_status=ToolCallParseStatus.succeeded,
-                ),
-                stop_reason=stop_reason,
-            )
-        )
+        if tool_call.tool_name in request_tools:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        tool_call=tool_call,
+                        parse_status=ToolCallParseStatus.succeeded,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+        else:
+            logger.warning(f"Tool {tool_call.tool_name} not found in request tools")
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        # Parsing tool call failed due to tool call not being found in request tools,
+                        # We still add the raw message text inside tool_call for responding back to the user
+                        tool_call=buffer,
+                        parse_status=ToolCallParseStatus.failed,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )

     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
```
```diff
@@ -158,7 +158,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
     "question,expected",
     [
         ("Which planet do humans live on?", "Earth"),
-        ("Which planet has rings around it with a name starting with letter S?", "Saturn"),
+        (
+            "Which planet has rings around it with a name starting with letter S?",
+            "Saturn",
+        ),
     ],
 )
 def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
```
```diff
@@ -280,3 +283,82 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
     assert answer.last_name == "Jordan"
     assert answer.year_of_birth == 1963
     assert answer.num_seasons_in_nba == 15
+
+
+@pytest.mark.parametrize(
+    "streaming",
+    [
+        True,
+        False,
+    ],
+)
+def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
+    # TODO: more dynamic lookup on tool_prompt_format for model family
+    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+    request = {
+        "model_id": text_model_id,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "What pods are in the namespace openshift-lightspeed?",
+            },
+            {
+                "role": "assistant",
+                "content": "",
+                "stop_reason": "end_of_turn",
+                "tool_calls": [
+                    {
+                        "call_id": "1",
+                        "tool_name": "get_object_namespace_list",
+                        "arguments": {
+                            "kind": "pod",
+                            "namespace": "openshift-lightspeed",
+                        },
+                    }
+                ],
+            },
+            {
+                "role": "tool",
+                "call_id": "1",
+                "tool_name": "get_object_namespace_list",
+                "content": "the objects are pod1, pod2, pod3",
+            },
+        ],
+        "tools": [
+            {
+                "tool_name": "get_object_namespace_list",
+                "description": "Get the list of objects in a namespace",
+                "parameters": {
+                    "kind": {
+                        "param_type": "string",
+                        "description": "the type of object",
+                        "required": True,
+                    },
+                    "namespace": {
+                        "param_type": "string",
+                        "description": "the name of the namespace",
+                        "required": True,
+                    },
+                },
+            }
+        ],
+        "tool_choice": "auto",
+        "tool_prompt_format": tool_prompt_format,
+        "stream": streaming,
+    }
+
+    response = llama_stack_client.inference.chat_completion(**request)
+
+    if streaming:
+        for chunk in response:
+            delta = chunk.event.delta
+            if delta.type == "tool_call" and delta.parse_status == "succeeded":
+                assert delta.tool_call.tool_name == "get_object_namespace_list"
+            if delta.type == "tool_call" and delta.parse_status == "failed":
+                # expect raw message that failed to parse in tool_call
+                assert type(delta.tool_call) == str
+                assert len(delta.tool_call) > 0
+    else:
+        for tc in response.completion_message.tool_calls:
+            assert tc.tool_name == "get_object_namespace_list"
```