unit test + fireworks streaming

Xi Yan 2025-02-10 21:47:10 -08:00
parent 0f062a15ec
commit aa04867d3a
4 changed files with 115 additions and 13 deletions


@@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin):
                 if delta.type == "tool_call":
                     if delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_calls.append(delta.tool_call)
+                    elif delta.parse_status == ToolCallParseStatus.failed:
+                        # If we cannot parse the tools, set the content to the unparsed raw text
+                        content = delta.tool_call
                 if stream:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
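With this change, a stream consumer can rely on delta.parse_status to choose between a structured tool call and a raw-text fallback. A minimal hedged sketch of the consumer side (the chunks list is a stand-in for a real streaming response; the string-valued statuses match how the client SDK test at the end of this commit checks them):

    # Hedged sketch: collect parsed tool calls, fall back to raw text on failure.
    chunks = []  # stand-in: replace with a real agent/inference streaming response
    tool_calls = []
    content = ""
    for chunk in chunks:
        delta = chunk.event.delta
        if delta.type != "tool_call":
            continue
        if delta.parse_status == "succeeded":
            tool_calls.append(delta.tool_call)  # structured ToolCall object
        elif delta.parse_status == "failed":
            content = delta.tool_call  # unparsed raw text, surfaced back to the user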


@@ -244,7 +244,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
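The only functional change in the Fireworks adapter is threading the original request into process_chat_completion_stream_response, so the shared stream processor can validate decoded tool names against request.tools. For context, a hedged sketch of the surrounding pattern (sync_chat_stream is a hypothetical stand-in for the Fireworks SDK call; only _to_async_generator and the final loop appear in the diff):

    # Hedged sketch: a synchronous SDK stream re-exposed as an async generator,
    # then post-processed with the request available for tool-name validation.
    async def _stream_chat_completion(request, formatter, sync_chat_stream):
        async def _to_async_generator():
            for chunk in sync_chat_stream():  # hypothetical synchronous SDK stream
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, formatter, request):
            yield chunk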


@@ -251,7 +251,9 @@ async def process_completion_stream_response(


 async def process_chat_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
@@ -334,7 +336,6 @@ async def process_chat_completion_stream_response(

         # parse tool calls and report errors
         message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
-        print(f"Parse TOOL CALLS message: {message}")

         parsed_tool_calls = len(message.tool_calls) > 0
         if ipython and not parsed_tool_calls:
@@ -349,17 +350,33 @@ async def process_chat_completion_stream_response(
             )
         )

+    request_tools = {t.tool_name: t for t in request.tools}
     for tool_call in message.tool_calls:
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.progress,
-                delta=ToolCallDelta(
-                    tool_call=tool_call,
-                    parse_status=ToolCallParseStatus.succeeded,
-                ),
-                stop_reason=stop_reason,
-            )
-        )
+        if tool_call.tool_name in request_tools:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        tool_call=tool_call,
+                        parse_status=ToolCallParseStatus.succeeded,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+        else:
+            logger.warning(f"Tool {tool_call.tool_name} not found in request tools")
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        # Parsing failed because the tool is not in the request's tool list;
+                        # return the raw message text in tool_call so it can be shown to the user
+                        tool_call=buffer,
+                        parse_status=ToolCallParseStatus.failed,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )

     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
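This guard is the core of the fix: a decoded tool call whose name was never declared in the request is downgraded to a failed parse, and the raw buffer text is returned instead of a fabricated ToolCall. A self-contained toy of the same logic (SimpleNamespace stands in for the real ToolDefinition/ToolCall types; all values are illustrative):

    # Toy illustration of the request_tools guard; runnable as-is.
    from types import SimpleNamespace

    request = SimpleNamespace(tools=[SimpleNamespace(tool_name="get_object_namespace_list")])
    decoded_calls = [SimpleNamespace(tool_name="get_weather")]  # e.g. hallucinated by the model
    buffer = "raw model output that mentioned get_weather"

    request_tools = {t.tool_name: t for t in request.tools}
    for tool_call in decoded_calls:
        if tool_call.tool_name in request_tools:
            print("succeeded:", tool_call.tool_name)
        else:
            print("failed, surfacing raw text:", buffer)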


@@ -158,7 +158,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
     "question,expected",
     [
         ("Which planet do humans live on?", "Earth"),
-        ("Which planet has rings around it with a name starting with letter S?", "Saturn"),
+        (
+            "Which planet has rings around it with a name starting with letter S?",
+            "Saturn",
+        ),
     ],
 )
 def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
@@ -280,3 +283,82 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
     assert answer.last_name == "Jordan"
     assert answer.year_of_birth == 1963
     assert answer.num_seasons_in_nba == 15
+
+
+@pytest.mark.parametrize(
+    "streaming",
+    [
+        True,
+        False,
+    ],
+)
+def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
+    # TODO: more dynamic lookup on tool_prompt_format for model family
+    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+    request = {
+        "model_id": text_model_id,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "What pods are in the namespace openshift-lightspeed?",
+            },
+            {
+                "role": "assistant",
+                "content": "",
+                "stop_reason": "end_of_turn",
+                "tool_calls": [
+                    {
+                        "call_id": "1",
+                        "tool_name": "get_object_namespace_list",
+                        "arguments": {
+                            "kind": "pod",
+                            "namespace": "openshift-lightspeed",
+                        },
+                    }
+                ],
+            },
+            {
+                "role": "tool",
+                "call_id": "1",
+                "tool_name": "get_object_namespace_list",
+                "content": "the objects are pod1, pod2, pod3",
+            },
+        ],
+        "tools": [
+            {
+                "tool_name": "get_object_namespace_list",
+                "description": "Get the list of objects in a namespace",
+                "parameters": {
+                    "kind": {
+                        "param_type": "string",
+                        "description": "the type of object",
+                        "required": True,
+                    },
+                    "namespace": {
+                        "param_type": "string",
+                        "description": "the name of the namespace",
+                        "required": True,
+                    },
+                },
+            }
+        ],
+        "tool_choice": "auto",
+        "tool_prompt_format": tool_prompt_format,
+        "stream": streaming,
+    }
+
+    response = llama_stack_client.inference.chat_completion(**request)
+
+    if streaming:
+        for chunk in response:
+            delta = chunk.event.delta
+            if delta.type == "tool_call" and delta.parse_status == "succeeded":
+                assert delta.tool_call.tool_name == "get_object_namespace_list"
+            if delta.type == "tool_call" and delta.parse_status == "failed":
+                # expect the raw message that failed to parse in tool_call
+                assert isinstance(delta.tool_call, str)
+                assert len(delta.tool_call) > 0
+    else:
+        for tc in response.completion_message.tool_calls:
+            assert tc.tool_name == "get_object_namespace_list"