mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 02:58:21 +00:00)

commit aa04867d3a (parent 0f062a15ec)

    unit test + fireworks streaming

4 changed files with 115 additions and 13 deletions
@@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin):
                 if delta.type == "tool_call":
                     if delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_calls.append(delta.tool_call)
+                    elif delta.parse_status == ToolCallParseStatus.failed:
+                        # If we cannot parse the tools, set the content to the unparsed raw text
+                        content = delta.tool_call
                 if stream:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
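This is the agent-side half of the change: a `failed` tool-call parse no longer vanishes; the unparsed raw text is carried forward as message content. A self-contained sketch of that accumulation logic, using stand-in types rather than llama-stack's actual classes:

```python
from dataclasses import dataclass
from typing import Any, Iterable, List, Tuple


@dataclass
class Delta:
    # Illustrative stand-in for the deltas consumed above.
    type: str                 # "tool_call" or "text"
    tool_call: Any = None     # structured ToolCall on success, raw str on failure
    parse_status: str = ""    # "succeeded" or "failed"
    text: str = ""


def accumulate(deltas: Iterable[Delta]) -> Tuple[List[Any], str]:
    """Keep successfully parsed tool calls; fall back to raw text otherwise."""
    tool_calls: List[Any] = []
    content = ""
    for delta in deltas:
        if delta.type == "tool_call":
            if delta.parse_status == "succeeded":
                tool_calls.append(delta.tool_call)
            elif delta.parse_status == "failed":
                # Same fallback as the hunk above: surface the unparsed raw text.
                content = delta.tool_call
        elif delta.type == "text":
            content += delta.text
    return tool_calls, content
```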
@@ -244,7 +244,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
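The visible change here is just threading `request` into the shared stream processor (the signature changes in the next hunk). For readers unfamiliar with `_to_async_generator`: its body is not part of this diff, but the usual shape of such a sync-to-async bridge is sketched below; the names and the executor choice are assumptions, not the adapter's actual code.

```python
import asyncio
from typing import Any, AsyncGenerator, Callable, Iterator


async def to_async_generator(
    make_stream: Callable[[], Iterator[Any]],
) -> AsyncGenerator[Any, None]:
    """Drain a blocking iterator on a worker thread, yielding items asynchronously."""
    loop = asyncio.get_running_loop()
    iterator = await loop.run_in_executor(None, lambda: iter(make_stream()))
    sentinel = object()
    while True:
        item = await loop.run_in_executor(None, next, iterator, sentinel)
        if item is sentinel:
            return
        yield item
```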
@@ -251,7 +251,9 @@ async def process_completion_stream_response(


 async def process_chat_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
@@ -334,7 +336,6 @@ async def process_chat_completion_stream_response(

     # parse tool calls and report errors
     message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
-    print(f"Parse TOOL CALLS message: {message}")

     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
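The stray debug print is removed outright. If decode tracing is ever wanted again, a leveled logger call is the non-invasive form; a tiny sketch (the `logger` name matches the one used by the warning in the next hunk, but this helper itself is illustrative):

```python
import logging

logger = logging.getLogger(__name__)


def trace_decoded_message(message: object) -> None:
    # Emitted only when DEBUG logging is enabled, unlike the removed print().
    logger.debug("decoded assistant message: %s", message)
```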
@@ -349,17 +350,33 @@ async def process_chat_completion_stream_response(
             )
         )

+    request_tools = {t.tool_name: t for t in request.tools}
     for tool_call in message.tool_calls:
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.progress,
-                delta=ToolCallDelta(
-                    tool_call=tool_call,
-                    parse_status=ToolCallParseStatus.succeeded,
-                ),
-                stop_reason=stop_reason,
-            )
-        )
+        if tool_call.tool_name in request_tools:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        tool_call=tool_call,
+                        parse_status=ToolCallParseStatus.succeeded,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+        else:
+            logger.warning(f"Tool {tool_call.tool_name} not found in request tools")
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        # Parsing tool call failed due to tool call not being found in request tools,
+                        # We still add the raw message text inside tool_call for responding back to the user
+                        tool_call=buffer,
+                        parse_status=ToolCallParseStatus.failed,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )

     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
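This gate is the heart of the commit: a decoded tool call is only surfaced as `succeeded` when the request actually declared that tool; anything else is logged and handed back as a `failed` delta carrying the raw buffer, so the client still sees the model's text. A standalone sketch of that decision with simplified stand-in types:

```python
from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass
class ToolCall:
    # Simplified stand-in for llama-stack's ToolCall.
    tool_name: str


def gate_tool_calls(
    decoded: List[ToolCall],
    declared_tools: List[str],
    buffer: str,
) -> Iterator[Tuple[str, object]]:
    """Yield ("succeeded", call) for declared tools, else ("failed", raw_text)."""
    known = set(declared_tools)
    for call in decoded:
        if call.tool_name in known:
            yield ("succeeded", call)
        else:
            # Mirror the diff: keep the unparsed raw text so the caller can
            # still show the model's output to the user.
            yield ("failed", buffer)
```

For example, `list(gate_tool_calls([ToolCall("made_up_tool")], ["get_object_namespace_list"], "raw text"))` yields `[("failed", "raw text")]`, which is exactly the fallback the streaming test below asserts on.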
@@ -158,7 +158,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
     "question,expected",
     [
         ("Which planet do humans live on?", "Earth"),
-        ("Which planet has rings around it with a name starting with letter S?", "Saturn"),
+        (
+            "Which planet has rings around it with a name starting with letter S?",
+            "Saturn",
+        ),
     ],
 )
 def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
@@ -280,3 +283,82 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
     assert answer.last_name == "Jordan"
     assert answer.year_of_birth == 1963
     assert answer.num_seasons_in_nba == 15
+
+
+@pytest.mark.parametrize(
+    "streaming",
+    [
+        True,
+        False,
+    ],
+)
+def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
+    # TODO: more dynamic lookup on tool_prompt_format for model family
+    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+    request = {
+        "model_id": text_model_id,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "What pods are in the namespace openshift-lightspeed?",
+            },
+            {
+                "role": "assistant",
+                "content": "",
+                "stop_reason": "end_of_turn",
+                "tool_calls": [
+                    {
+                        "call_id": "1",
+                        "tool_name": "get_object_namespace_list",
+                        "arguments": {
+                            "kind": "pod",
+                            "namespace": "openshift-lightspeed",
+                        },
+                    }
+                ],
+            },
+            {
+                "role": "tool",
+                "call_id": "1",
+                "tool_name": "get_object_namespace_list",
+                "content": "the objects are pod1, pod2, pod3",
+            },
+        ],
+        "tools": [
+            {
+                "tool_name": "get_object_namespace_list",
+                "description": "Get the list of objects in a namespace",
+                "parameters": {
+                    "kind": {
+                        "param_type": "string",
+                        "description": "the type of object",
+                        "required": True,
+                    },
+                    "namespace": {
+                        "param_type": "string",
+                        "description": "the name of the namespace",
+                        "required": True,
+                    },
+                },
+            }
+        ],
+        "tool_choice": "auto",
+        "tool_prompt_format": tool_prompt_format,
+        "stream": streaming,
+    }
+
+    response = llama_stack_client.inference.chat_completion(**request)
+
+    if streaming:
+        for chunk in response:
+            delta = chunk.event.delta
+            if delta.type == "tool_call" and delta.parse_status == "succeeded":
+                assert delta.tool_call.tool_name == "get_object_namespace_list"
+            if delta.type == "tool_call" and delta.parse_status == "failed":
+                # expect raw message that failed to parse in tool_call
+                assert type(delta.tool_call) == str
+                assert len(delta.tool_call) > 0
+    else:
+        for tc in response.completion_message.tool_calls:
+            assert tc.tool_name == "get_object_namespace_list"
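The streaming branch of this test doubles as documentation of the new contract. A hedged helper that factors the same checks so they could be reused against other tools (the `response` shapes follow the llama-stack client exactly as exercised above; the helper itself is illustrative and not part of the diff):

```python
def assert_tool_call_behavior(response, streaming: bool, tool_name: str) -> None:
    """Reusable form of the assertions in the test above (illustrative)."""
    if streaming:
        for chunk in response:
            delta = chunk.event.delta
            if delta.type != "tool_call":
                continue
            if delta.parse_status == "succeeded":
                assert delta.tool_call.tool_name == tool_name
            elif delta.parse_status == "failed":
                # The failed path now carries the raw, unparsed model text.
                assert isinstance(delta.tool_call, str) and len(delta.tool_call) > 0
    else:
        for tc in response.completion_message.tool_calls:
            assert tc.tool_name == tool_name
```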