forked from phoenix-oss/llama-stack-mirror
llama-models should have extremely minimal cruft. Its sole purpose should be didactic -- show the simplest implementation of the llama models and document the prompt formats, etc. This PR is the complement to https://github.com/meta-llama/llama-models/pull/279 ## Test Plan Ensure all `llama` CLI `model` sub-commands work: ```bash llama model list llama model download --model-id ... llama model prompt-format -m ... ``` Ran tests: ```bash cd tests/client-sdk LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/ LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/ LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/ ``` Create a fresh venv `uv venv && source .venv/bin/activate` and run `llama stack build --template fireworks --image-type venv` followed by `llama stack run together --image-type venv` <-- the server runs Also checked that the OpenAPI generator can run and there is no change in the generated files as a result. ```bash cd docs/openapi_generator sh run_openapi_generator.sh ```
575 lines
20 KiB
Python
575 lines
20 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
|
|
import pytest
|
|
from groq.types.chat.chat_completion import ChatCompletion, Choice
|
|
from groq.types.chat.chat_completion_chunk import (
|
|
ChatCompletionChunk,
|
|
ChoiceDelta,
|
|
ChoiceDeltaToolCall,
|
|
ChoiceDeltaToolCallFunction,
|
|
)
|
|
from groq.types.chat.chat_completion_chunk import (
|
|
Choice as StreamChoice,
|
|
)
|
|
from groq.types.chat.chat_completion_message import ChatCompletionMessage
|
|
from groq.types.chat.chat_completion_message_tool_call import (
|
|
ChatCompletionMessageToolCall,
|
|
Function,
|
|
)
|
|
from groq.types.shared.function_definition import FunctionDefinition
|
|
|
|
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponseEventType,
|
|
CompletionMessage,
|
|
StopReason,
|
|
SystemMessage,
|
|
ToolCall,
|
|
ToolChoice,
|
|
ToolDefinition,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy
|
|
from llama_stack.providers.remote.inference.groq.groq_utils import (
|
|
convert_chat_completion_request,
|
|
convert_chat_completion_response,
|
|
convert_chat_completion_response_stream,
|
|
)
|
|
|
|
|
|
class TestConvertChatCompletionRequest:
    """Tests for ``convert_chat_completion_request``.

    Each test starts from a minimal ``ChatCompletionRequest`` (see
    ``_dummy_chat_completion_request``), tweaks one field, converts it and
    inspects the resulting dict of Groq keyword arguments.
    """

    def test_sets_model(self):
        request = self._dummy_chat_completion_request()
        request.model = "Llama-3.2-3B"

        converted = convert_chat_completion_request(request)

        assert converted["model"] == "Llama-3.2-3B"

    def test_converts_user_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [UserMessage(content="Hello World")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
        ]

    def test_converts_system_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [SystemMessage(content="You are a helpful assistant.")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "system", "content": "You are a helpful assistant."},
        ]

    def test_converts_completion_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [
            UserMessage(content="Hello World"),
            CompletionMessage(
                content="Hello World! How can I help you today?",
                stop_reason=StopReason.end_of_message,
            ),
        ]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
            {"role": "assistant", "content": "Hello World! How can I help you today?"},
        ]

    def test_does_not_include_logprobs(self):
        request = self._dummy_chat_completion_request()
        request.logprobs = True

        # Match on the message instead of indexing the first recorded warning:
        # an unrelated warning emitted earlier would otherwise break the test.
        with pytest.warns(Warning, match="logprobs are not supported yet"):
            converted = convert_chat_completion_request(request)

        assert converted.get("logprobs") is None

    def test_does_not_include_response_format(self):
        request = self._dummy_chat_completion_request()
        request.response_format = {
            "type": "json_object",
            "json_schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "number"},
                },
            },
        }

        with pytest.warns(Warning, match="response_format is not supported yet"):
            converted = convert_chat_completion_request(request)

        assert converted.get("response_format") is None

    def test_does_not_include_repetition_penalty(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.repetition_penalty = 1.5

        with pytest.warns(Warning, match="repetition_penalty is not supported"):
            converted = convert_chat_completion_request(request)

        # Neither the original name nor the closest Groq equivalent is sent.
        assert converted.get("repetition_penalty") is None
        assert converted.get("frequency_penalty") is None

    def test_includes_stream(self):
        request = self._dummy_chat_completion_request()
        request.stream = True

        converted = convert_chat_completion_request(request)

        assert converted["stream"] is True

    def test_if_max_tokens_is_0_then_it_is_not_included(self):
        request = self._dummy_chat_completion_request()
        # 0 is the default value for max_tokens
        # So we assume that if it's 0, the user didn't set it
        request.sampling_params.max_tokens = 0

        converted = convert_chat_completion_request(request)

        assert converted.get("max_tokens") is None

    def test_includes_max_tokens_if_set(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.max_tokens = 100

        converted = convert_chat_completion_request(request)

        assert converted["max_tokens"] == 100

    def test_includes_strategy(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.strategy = TopPSamplingStrategy(temperature=0.5, top_p=0.95)

        converted = convert_chat_completion_request(request)

        assert converted["temperature"] == 0.5
        assert converted["top_p"] == 0.95

    def test_includes_greedy_strategy(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.strategy = GreedySamplingStrategy()

        converted = convert_chat_completion_request(request)

        # Greedy decoding is expressed as a zero temperature.
        assert converted["temperature"] == 0.0

    def test_includes_tool_choice(self):
        request = self._dummy_chat_completion_request()
        request.tool_config.tool_choice = ToolChoice.required

        converted = convert_chat_completion_request(request)

        assert converted["tool_choice"] == "required"

    def test_includes_tools(self):
        request = self._dummy_chat_completion_request()
        request.tools = [
            ToolDefinition(
                tool_name="get_flight_info",
                description="Get fight information between two destinations.",
                parameters={
                    "origin": ToolParamDefinition(
                        param_type="string",
                        description="The origin airport code. E.g., AU",
                        required=True,
                    ),
                    "destination": ToolParamDefinition(
                        param_type="string",
                        description="The destination airport code. E.g., 'LAX'",
                        required=True,
                    ),
                    "passengers": ToolParamDefinition(
                        param_type="array",
                        description="The passengers",
                        required=False,
                    ),
                },
            ),
            ToolDefinition(
                tool_name="log",
                description="Calulate the logarithm of a number",
                parameters={
                    "number": ToolParamDefinition(
                        param_type="float",
                        description="The number to calculate the logarithm of",
                        required=True,
                    ),
                    "base": ToolParamDefinition(
                        param_type="integer",
                        description="The base of the logarithm",
                        required=False,
                        default=10,
                    ),
                },
            ),
        ]

        converted = convert_chat_completion_request(request)

        # Each ToolDefinition becomes a Groq "function" tool whose parameters
        # map param_type -> "type" and carry description/required/default.
        assert converted["tools"] == [
            {
                "type": "function",
                "function": FunctionDefinition(
                    name="get_flight_info",
                    description="Get fight information between two destinations.",
                    parameters={
                        "origin": {
                            "type": "string",
                            "description": "The origin airport code. E.g., AU",
                            "required": True,
                        },
                        "destination": {
                            "type": "string",
                            "description": "The destination airport code. E.g., 'LAX'",
                            "required": True,
                        },
                        "passengers": {
                            "type": "array",
                            "description": "The passengers",
                            "required": False,
                        },
                    },
                ),
            },
            {
                "type": "function",
                "function": FunctionDefinition(
                    name="log",
                    description="Calulate the logarithm of a number",
                    parameters={
                        "number": {
                            "type": "float",
                            "description": "The number to calculate the logarithm of",
                            "required": True,
                        },
                        "base": {
                            "type": "integer",
                            "description": "The base of the logarithm",
                            "required": False,
                            "default": 10,
                        },
                    },
                ),
            },
        ]

    def _dummy_chat_completion_request(self):
        """Return a minimal request: one user message for ``Llama-3.2-3B``."""
        return ChatCompletionRequest(
            model="Llama-3.2-3B",
            messages=[UserMessage(content="Hello World")],
        )
|
|
|
|
|
|
class TestConvertNonStreamChatCompletionResponse:
    """Tests for ``convert_chat_completion_response`` (non-streaming Groq responses)."""

    def test_returns_response(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].message.content = "Hello World"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.content == "Hello World"

    def test_maps_stop_to_end_of_turn(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "stop"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_turn

    def test_maps_length_to_out_of_tokens(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "length"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.out_of_tokens

    def test_maps_tool_call_to_end_of_message(self):
        response = self._dummy_chat_completion_response_with_tool_call()

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_message

    def test_converts_multiple_tool_calls(self):
        response = self._dummy_chat_completion_response_with_tool_call()
        response.choices[0].message.tool_calls = [
            ChatCompletionMessageToolCall(
                id="tool_call_id",
                type="function",
                function=Function(
                    name="get_flight_info",
                    arguments='{"origin": "AU", "destination": "LAX"}',
                ),
            ),
            ChatCompletionMessageToolCall(
                id="tool_call_id_2",
                type="function",
                function=Function(
                    name="log",
                    arguments='{"number": 10, "base": 2}',
                ),
            ),
        ]

        converted = convert_chat_completion_response(response)

        # JSON argument strings are decoded into dicts on the ToolCall.
        assert converted.completion_message.tool_calls == [
            ToolCall(
                call_id="tool_call_id",
                tool_name="get_flight_info",
                arguments={"origin": "AU", "destination": "LAX"},
            ),
            ToolCall(
                call_id="tool_call_id_2",
                tool_name="log",
                arguments={"number": 10, "base": 2},
            ),
        ]

    def test_converts_unparseable_tool_calls(self):
        response = self._dummy_chat_completion_response_with_tool_call()
        response.choices[0].message.tool_calls = [
            ChatCompletionMessageToolCall(
                id="tool_call_id",
                type="function",
                function=Function(
                    name="log",
                    arguments="(number=10, base=2)",
                ),
            ),
        ]

        converted = convert_chat_completion_response(response)

        # Arguments that are not valid JSON are surfaced as serialized text in
        # the message content instead of structured tool calls.
        assert (
            converted.completion_message.content
            == '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
        )

    def _dummy_chat_completion_response(self):
        """Return a one-choice Groq completion with text content and ``finish_reason="stop"``."""
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(role="assistant", content="Hello World"),
                    finish_reason="stop",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )

    def _dummy_chat_completion_response_with_tool_call(self):
        """Return a one-choice Groq completion holding a single ``get_flight_info`` tool call."""
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant",
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="tool_call_id",
                                type="function",
                                function=Function(
                                    name="get_flight_info",
                                    arguments='{"origin": "AU", "destination": "LAX"}',
                                ),
                            )
                        ],
                    ),
                    finish_reason="tool_calls",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )
|
|
|
|
|
|
class TestConvertStreamChatCompletionResponse:
    """Tests for ``convert_chat_completion_response_stream``.

    Each test feeds a synchronous generator of Groq ``ChatCompletionChunk``
    objects to the converter and steps through the resulting async iterator.
    """

    @pytest.mark.asyncio
    async def test_returns_stream(self):
        def chat_completion_stream():
            messages = ["Hello ", "World ", " !"]
            for message in messages:
                chunk = self._dummy_chat_completion_chunk()
                chunk.choices[0].delta.content = message
                yield chunk

            # Terminal chunk: no content, only a finish reason.
            chunk = self._dummy_chat_completion_chunk()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = chat_completion_stream()
        converted = convert_chat_completion_response_stream(stream)

        # Named to avoid shadowing the builtin ``iter``.
        events = converted.__aiter__()
        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta.text == "Hello "

        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta.text == "World "

        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta.text == " !"

        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
        assert chunk.event.delta.text == ""
        assert chunk.event.stop_reason == StopReason.end_of_turn

        with pytest.raises(StopAsyncIteration):
            await events.__anext__()

    @pytest.mark.asyncio
    async def test_returns_tool_calls_stream(self):
        def tool_call_stream():
            tool_calls = [
                ToolCall(
                    call_id="tool_call_id",
                    tool_name="get_flight_info",
                    arguments={"origin": "AU", "destination": "LAX"},
                ),
                ToolCall(
                    call_id="tool_call_id_2",
                    tool_name="log",
                    arguments={"number": 10, "base": 2},
                ),
            ]
            for tool_call in tool_calls:
                chunk = self._dummy_chat_completion_chunk_with_tool_call()
                chunk.choices[0].delta.tool_calls = [
                    ChoiceDeltaToolCall(
                        index=0,
                        type="function",
                        id=tool_call.call_id,
                        function=ChoiceDeltaToolCallFunction(
                            name=tool_call.tool_name,
                            arguments=json.dumps(tool_call.arguments),
                        ),
                    ),
                ]
                yield chunk

            chunk = self._dummy_chat_completion_chunk_with_tool_call()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = tool_call_stream()
        converted = convert_chat_completion_response_stream(stream)

        events = converted.__aiter__()
        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta.tool_call == ToolCall(
            call_id="tool_call_id",
            tool_name="get_flight_info",
            arguments={"origin": "AU", "destination": "LAX"},
        )

    @pytest.mark.asyncio
    async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self):
        def tool_call_stream():
            chunk = self._dummy_chat_completion_chunk_with_tool_call()
            chunk.choices[0].delta.tool_calls = [
                ChoiceDeltaToolCall(
                    index=0,
                    type="function",
                    id="tool_call_id",
                    function=ChoiceDeltaToolCallFunction(
                        name="get_flight_info",
                        arguments="(origin=AU, destination=LAX)",
                    ),
                ),
            ]
            yield chunk

            chunk = self._dummy_chat_completion_chunk_with_tool_call()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = tool_call_stream()
        converted = convert_chat_completion_response_stream(stream)

        events = converted.__aiter__()
        chunk = await events.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        # Non-JSON arguments are reported as raw serialized content with a
        # failed parse status rather than as a structured ToolCall.
        assert (
            chunk.event.delta.content
            == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
        )
        assert chunk.event.delta.parse_status == ToolCallParseStatus.failed

    def _dummy_chat_completion_chunk(self):
        """Return a one-choice Groq stream chunk with text content only."""
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(role="assistant", content="Hello World"),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )

    def _dummy_chat_completion_chunk_with_tool_call(self):
        """Return a one-choice Groq stream chunk carrying text and a ``get_flight_info`` tool call."""
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello World",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="get_flight_info",
                                    arguments='{"origin": "AU", "destination": "LAX"}',
                                ),
                            )
                        ],
                    ),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )
|