forked from phoenix-oss/llama-stack-mirror
# What does this PR do? Follow up for @ashwinb's comments in https://github.com/meta-llama/llama-stack/pull/630 - [x] Contributes to issue (#432) ## Test Plan <details> <summary>Environment</summary> ```shell export GROQ_API_KEY=<api-key> # Create environment if not already conda create --name llamastack-groq python=3.10 conda activate llamastack-groq wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml wget https://raw.githubusercontent.com/meta-llama/llama-stack/918172c7fa92522c9ebc586bdb4f386b1d9ea224/run.yaml # Build pip install -e . && llama stack build --config ./build.yaml --image-type conda # Activate built environment conda activate llamastack-groq # Test deps pip install pytest pytest_html pytest_asyncio ``` </details> <details> <summary>Unit tests</summary> ```shell # Setup conda activate llamastack-groq pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s # Result llama_stack/providers/tests/inference/groq/test_groq_utils.py ....................... ========================================= 23 passed, 11 warnings in 0.06s ========================================= ``` </details> <details> <summary>Integration tests</summary> ```shell # Tests pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s # Results ___________________________ TestInference.test_chat_completion_with_tool_calling[-groq] ___________________________ llama_stack/providers/tests/inference/test_text_inference.py:403: in test_chat_completion_with_tool_calling assert len(message.tool_calls) > 0 E assert 0 > 0 E + where 0 = len([]) E + where [] = CompletionMessage(role='assistant', content='<function=get_weather>{"location": "San Francisco, CA"}', stop_reason=<StopReason.end_of_turn: 'end_of_turn'>, tool_calls=[]).tool_calls ============================================= short test summary info ============================================= FAILED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-groq] - assert 0 > 0 ======================== 1 failed, 3 passed, 5 skipped, 99 deselected, 7 warnings in 2.13s ======================== ``` (One failure as expected from 3.2 3B - re: https://github.com/meta-llama/llama-stack/pull/630#discussion_r1914056503) </details> ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
d5b7de3897
commit
39c34dd25f
2 changed files with 126 additions and 23 deletions
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from typing import AsyncGenerator, Literal
|
from typing import AsyncGenerator, Literal, Union
|
||||||
|
|
||||||
from groq import Stream
|
from groq import Stream
|
||||||
from groq.types.chat.chat_completion import ChatCompletion
|
from groq.types.chat.chat_completion import ChatCompletion
|
||||||
|
@ -30,6 +30,8 @@ from groq.types.shared.function_definition import FunctionDefinition
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
TextDelta,
|
TextDelta,
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
|
@ -150,6 +152,17 @@ def convert_chat_completion_response(
|
||||||
_convert_groq_tool_call(tool_call)
|
_convert_groq_tool_call(tool_call)
|
||||||
for tool_call in choice.message.tool_calls
|
for tool_call in choice.message.tool_calls
|
||||||
]
|
]
|
||||||
|
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
|
||||||
|
# If we couldn't parse a tool call, jsonify the tool calls and return them
|
||||||
|
return ChatCompletionResponse(
|
||||||
|
completion_message=CompletionMessage(
|
||||||
|
stop_reason=StopReason.end_of_message,
|
||||||
|
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
|
||||||
|
),
|
||||||
|
logprobs=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Otherwise, return tool calls as normal
|
||||||
return ChatCompletionResponse(
|
return ChatCompletionResponse(
|
||||||
completion_message=CompletionMessage(
|
completion_message=CompletionMessage(
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
|
@ -214,6 +227,7 @@ async def convert_chat_completion_response_stream(
|
||||||
|
|
||||||
# We assume Groq produces fully formed tool calls for each chunk
|
# We assume Groq produces fully formed tool calls for each chunk
|
||||||
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
|
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
|
||||||
|
if isinstance(tool_call, ToolCall):
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
|
@ -223,6 +237,17 @@ async def convert_chat_completion_response_stream(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# Otherwise it's an UnparseableToolCall - return the raw tool call
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=event_type,
|
||||||
|
delta=ToolCallDelta(
|
||||||
|
tool_call=tool_call.model_dump_json(),
|
||||||
|
parse_status=ToolCallParseStatus.failed,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
|
@ -234,12 +259,35 @@ async def convert_chat_completion_response_stream(
|
||||||
event_type = ChatCompletionResponseEventType.progress
|
event_type = ChatCompletionResponseEventType.progress
|
||||||
|
|
||||||
|
|
||||||
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
|
class UnparseableToolCall(BaseModel):
|
||||||
|
"""
|
||||||
|
A ToolCall with arguments that are not valid JSON.
|
||||||
|
Mirrors the ToolCall schema, but with arguments as a string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
call_id: str
|
||||||
|
tool_name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_groq_tool_call(
|
||||||
|
tool_call: ChatCompletionMessageToolCall,
|
||||||
|
) -> Union[ToolCall, UnparseableToolCall]:
|
||||||
|
"""
|
||||||
|
Convert a Groq tool call to a ToolCall.
|
||||||
|
Returns an UnparseableToolCall if the tool call is not valid JSON.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
arguments = json.loads(tool_call.function.arguments)
|
||||||
|
except Exception as e:
|
||||||
|
return UnparseableToolCall(
|
||||||
|
call_id=tool_call.id,
|
||||||
|
tool_name=tool_call.function.name,
|
||||||
|
arguments=tool_call.function.arguments,
|
||||||
|
)
|
||||||
|
|
||||||
return ToolCall(
|
return ToolCall(
|
||||||
call_id=tool_call.id,
|
call_id=tool_call.id,
|
||||||
tool_name=tool_call.function.name,
|
tool_name=tool_call.function.name,
|
||||||
# Note that Groq may return a string that is not valid JSON here
|
arguments=arguments,
|
||||||
# So this may raise a 500 error. Going to leave this as is to see
|
|
||||||
# how big of an issue this is and what we can do about it.
|
|
||||||
arguments=json.loads(tool_call.function.arguments),
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
|
||||||
from groq.types.shared.function_definition import FunctionDefinition
|
from groq.types.shared.function_definition import FunctionDefinition
|
||||||
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
|
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
|
||||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||||
|
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
|
@ -347,6 +348,26 @@ class TestConvertNonStreamChatCompletionResponse:
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_converts_unparseable_tool_calls(self):
|
||||||
|
response = self._dummy_chat_completion_response_with_tool_call()
|
||||||
|
response.choices[0].message.tool_calls = [
|
||||||
|
ChatCompletionMessageToolCall(
|
||||||
|
id="tool_call_id",
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="log",
|
||||||
|
arguments="(number=10, base=2)",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
converted = convert_chat_completion_response(response)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
converted.completion_message.content
|
||||||
|
== '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
|
||||||
|
)
|
||||||
|
|
||||||
def _dummy_chat_completion_response(self):
|
def _dummy_chat_completion_response(self):
|
||||||
return ChatCompletion(
|
return ChatCompletion(
|
||||||
id="chatcmpl-123",
|
id="chatcmpl-123",
|
||||||
|
@ -478,6 +499,40 @@ class TestConvertStreamChatCompletionResponse:
|
||||||
arguments={"origin": "AU", "destination": "LAX"},
|
arguments={"origin": "AU", "destination": "LAX"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self):
|
||||||
|
def tool_call_stream():
|
||||||
|
chunk = self._dummy_chat_completion_chunk_with_tool_call()
|
||||||
|
chunk.choices[0].delta.tool_calls = [
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
index=0,
|
||||||
|
type="function",
|
||||||
|
id="tool_call_id",
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name="get_flight_info",
|
||||||
|
arguments="(origin=AU, destination=LAX)",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
chunk = self._dummy_chat_completion_chunk_with_tool_call()
|
||||||
|
chunk.choices[0].delta.content = None
|
||||||
|
chunk.choices[0].finish_reason = "stop"
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
stream = tool_call_stream()
|
||||||
|
converted = convert_chat_completion_response_stream(stream)
|
||||||
|
|
||||||
|
iter = converted.__aiter__()
|
||||||
|
chunk = await iter.__anext__()
|
||||||
|
assert chunk.event.event_type == ChatCompletionResponseEventType.start
|
||||||
|
assert (
|
||||||
|
chunk.event.delta.content
|
||||||
|
== '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
|
||||||
|
)
|
||||||
|
assert chunk.event.delta.parse_status == ToolCallParseStatus.failed
|
||||||
|
|
||||||
def _dummy_chat_completion_chunk(self):
|
def _dummy_chat_completion_chunk(self):
|
||||||
return ChatCompletionChunk(
|
return ChatCompletionChunk(
|
||||||
id="chatcmpl-123",
|
id="chatcmpl-123",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue