Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-02 02:40:01 +00:00)
Commit 76e08cfde0 ("PR tool call followups")
Parent: 1f60c0286d
3 changed files with 126 additions and 30 deletions
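This follow-up adds regression coverage for Groq tool calls whose `arguments` come back as a Pythonic string such as `(number=10, base=2)` rather than JSON, in both the non-streaming and streaming conversion paths, and removes a skip that had excluded Llama-3.2 models from the Groq tool-calling integration test.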
```diff
@@ -23,6 +23,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
 from groq.types.shared.function_definition import FunctionDefinition
 from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import ToolParamDefinition
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
```
```diff
@@ -347,6 +348,26 @@ class TestConvertNonStreamChatCompletionResponse:
             ),
         ]
 
+    def test_converts_unparseable_tool_calls(self):
+        response = self._dummy_chat_completion_response_with_tool_call()
+        response.choices[0].message.tool_calls = [
+            ChatCompletionMessageToolCall(
+                id="tool_call_id",
+                type="function",
+                function=Function(
+                    name="log",
+                    arguments="(number=10, base=2)",
+                ),
+            ),
+        ]
+
+        converted = convert_chat_completion_response(response)
+
+        assert (
+            converted.completion_message.content
+            == '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
+        )
+
     def _dummy_chat_completion_response(self):
         return ChatCompletion(
             id="chatcmpl-123",
```
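For context on what the non-stream test asserts: when the arguments string is not parseable as JSON, the converter appears to fall back to serializing the raw tool calls into `completion_message.content`. Below is a minimal sketch of that fallback under that assumption; the stand-in dataclasses and the function name are illustrative, not the adapter's actual code:

```python
import json
from dataclasses import dataclass


# Stand-ins for the Groq SDK's tool-call objects (assumption: only
# .id, .function.name, and .function.arguments matter here).
@dataclass
class _Function:
    name: str
    arguments: str


@dataclass
class _ToolCall:
    id: str
    function: _Function


def serialize_unparseable_tool_calls(tool_calls) -> str:
    # Hypothetical fallback: dump the raw tool calls into the message
    # content. json.dumps' default separators (', ' and ': ') reproduce
    # the exact string asserted in test_converts_unparseable_tool_calls.
    return json.dumps(
        [
            {
                "call_id": tc.id,
                "tool_name": tc.function.name,
                "arguments": tc.function.arguments,
            }
            for tc in tool_calls
        ]
    )


calls = [
    _ToolCall(id="tool_call_id", function=_Function(name="log", arguments="(number=10, base=2)"))
]
assert serialize_unparseable_tool_calls(calls) == (
    '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
)
```

Note that nothing is lost: the unparseable `arguments` string is passed through verbatim inside the JSON, so a client can still inspect what the model produced.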
```diff
@@ -478,6 +499,40 @@ class TestConvertStreamChatCompletionResponse:
             arguments={"origin": "AU", "destination": "LAX"},
         )
 
+    @pytest.mark.asyncio
+    async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self):
+        def tool_call_stream():
+            chunk = self._dummy_chat_completion_chunk_with_tool_call()
+            chunk.choices[0].delta.tool_calls = [
+                ChoiceDeltaToolCall(
+                    index=0,
+                    type="function",
+                    id="tool_call_id",
+                    function=ChoiceDeltaToolCallFunction(
+                        name="get_flight_info",
+                        arguments="(origin=AU, destination=LAX)",
+                    ),
+                ),
+            ]
+            yield chunk
+
+            chunk = self._dummy_chat_completion_chunk_with_tool_call()
+            chunk.choices[0].delta.content = None
+            chunk.choices[0].finish_reason = "stop"
+            yield chunk
+
+        stream = tool_call_stream()
+        converted = convert_chat_completion_response_stream(stream)
+
+        iter = converted.__aiter__()
+        chunk = await iter.__anext__()
+        assert chunk.event.event_type == ChatCompletionResponseEventType.start
+        assert (
+            chunk.event.delta.content
+            == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
+        )
+        assert chunk.event.delta.parse_status == ToolCallParseStatus.failed
+
     def _dummy_chat_completion_chunk(self):
         return ChatCompletionChunk(
             id="chatcmpl-123",
```
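The streaming assertion differs from the non-stream one in two details: the fallback content is a single compact JSON object per event (no spaces after `:` or `,`) rather than a spaced list, and the event carries `parse_status == ToolCallParseStatus.failed` so consumers can tell the content is a raw fallback rather than a parsed tool call. A sketch of just that serialization, assuming compact separators; the function name is illustrative:

```python
import json


def unparseable_tool_call_content(call_id: str, tool_name: str, arguments: str) -> str:
    # Hypothetical sketch of the streaming path: one compact JSON object
    # per failed call. separators=(",", ":") drops the spaces, matching
    # the string asserted in the streaming test above.
    return json.dumps(
        {"call_id": call_id, "tool_name": tool_name, "arguments": arguments},
        separators=(",", ":"),
    )


assert unparseable_tool_call_content(
    "tool_call_id", "get_flight_info", "(origin=AU, destination=LAX)"
) == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
```

The comma and space inside `(origin=AU, destination=LAX)` survive unchanged because compact separators only affect JSON structure, not string values.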
```diff
@@ -377,13 +377,6 @@ class TestInference:
         sample_tool_definition,
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if (
-            provider.__provider_spec__.provider_type == "remote::groq"
-            and "Llama-3.2" in inference_model
-        ):
-            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
-            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
 
         messages = sample_messages + [
             UserMessage(
```
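With the skip removed, this tool-calling test now runs against Llama-3.2 models on the `remote::groq` provider as well; presumably the TODO above was resolved and Groq's Llama-3.2 tool calling now behaves well enough to pass.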