mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 16:29:54 +00:00
Add Groq provider - chat completions
This commit is contained in:
parent
c294a01c4b
commit
378150e23c
10 changed files with 727 additions and 31 deletions
|
|
@ -19,6 +19,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
|||
|
||||
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.groq import GroqConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
||||
|
|
@ -150,6 +151,22 @@ def inference_together() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_groq() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="groq",
|
||||
provider_type="remote::groq",
|
||||
config=GroqConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
provider_data=dict(
|
||||
groq_api_key=get_env_or_fail("GROQ_API_KEY"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_bedrock() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
|
|
@ -222,6 +239,7 @@ INFERENCE_FIXTURES = [
|
|||
"ollama",
|
||||
"fireworks",
|
||||
"together",
|
||||
"groq",
|
||||
"vllm_remote",
|
||||
"remote",
|
||||
"bedrock",
|
||||
|
|
|
|||
278
llama_stack/providers/tests/inference/groq/test_groq_utils.py
Normal file
278
llama_stack/providers/tests/inference/groq/test_groq_utils.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from groq.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from groq.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk,
|
||||
Choice as StreamChoice,
|
||||
ChoiceDelta,
|
||||
)
|
||||
from groq.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.groq_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_chat_completion_response,
|
||||
convert_chat_completion_response_stream,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertChatCompletionRequest:
|
||||
def test_sets_model(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.model = "Llama-3.2-3B"
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["model"] == "Llama-3.2-3B"
|
||||
|
||||
def test_converts_user_message(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.messages = [UserMessage(content="Hello World")]
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["messages"] == [
|
||||
{"role": "user", "content": "Hello World"},
|
||||
]
|
||||
|
||||
def test_converts_system_message(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.messages = [SystemMessage(content="You are a helpful assistant.")]
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["messages"] == [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
def test_converts_completion_message(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.messages = [
|
||||
UserMessage(content="Hello World"),
|
||||
CompletionMessage(
|
||||
content="Hello World! How can I help you today?",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
),
|
||||
]
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["messages"] == [
|
||||
{"role": "user", "content": "Hello World"},
|
||||
{"role": "assistant", "content": "Hello World! How can I help you today?"},
|
||||
]
|
||||
|
||||
def test_does_not_include_logprobs(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.logprobs = True
|
||||
|
||||
with pytest.warns(Warning) as warnings:
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert "logprobs are not supported yet" in warnings[0].message.args[0]
|
||||
assert converted.get("logprobs") is None
|
||||
|
||||
def test_does_not_include_response_format(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.response_format = {
|
||||
"type": "json_object",
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "number"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.warns(Warning) as warnings:
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert "response_format is not supported yet" in warnings[0].message.args[0]
|
||||
assert converted.get("response_format") is None
|
||||
|
||||
def test_does_not_include_repetition_penalty(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.repetition_penalty = 1.5
|
||||
|
||||
with pytest.warns(Warning) as warnings:
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert "repetition_penalty is not supported" in warnings[0].message.args[0]
|
||||
assert converted.get("repetition_penalty") is None
|
||||
assert converted.get("frequency_penalty") is None
|
||||
|
||||
def test_includes_stream(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.stream = True
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["stream"] is True
|
||||
|
||||
def test_n_is_1(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["n"] == 1
|
||||
|
||||
def test_if_max_tokens_is_0_then_it_is_not_included(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
# 0 is the default value for max_tokens
|
||||
# So we assume that if it's 0, the user didn't set it
|
||||
request.sampling_params.max_tokens = 0
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted.get("max_tokens") is None
|
||||
|
||||
def test_includes_max_tokens_if_set(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.max_tokens = 100
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["max_tokens"] == 100
|
||||
|
||||
def _dummy_chat_completion_request(self):
|
||||
return ChatCompletionRequest(
|
||||
model="Llama-3.2-3B",
|
||||
messages=[UserMessage(content="Hello World")],
|
||||
)
|
||||
|
||||
def test_includes_temperature(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.temperature = 0.5
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["temperature"] == 0.5
|
||||
|
||||
def test_includes_top_p(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.top_p = 0.95
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["top_p"] == 0.95
|
||||
|
||||
|
||||
class TestConvertNonStreamChatCompletionResponse:
|
||||
def test_returns_response(self):
|
||||
response = self._dummy_chat_completion_response()
|
||||
response.choices[0].message.content = "Hello World"
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.content == "Hello World"
|
||||
|
||||
def test_maps_stop_to_end_of_message(self):
|
||||
response = self._dummy_chat_completion_response()
|
||||
response.choices[0].finish_reason = "stop"
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.stop_reason == StopReason.end_of_turn
|
||||
|
||||
def test_maps_length_to_end_of_message(self):
|
||||
response = self._dummy_chat_completion_response()
|
||||
response.choices[0].finish_reason = "length"
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.stop_reason == StopReason.end_of_message
|
||||
|
||||
def _dummy_chat_completion_response(self):
|
||||
return ChatCompletion(
|
||||
id="chatcmpl-123",
|
||||
model="Llama-3.2-3B",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content="Hello World"
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
created=1729382400,
|
||||
object="chat.completion",
|
||||
)
|
||||
|
||||
|
||||
class TestConvertStreamChatCompletionResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_stream(self):
|
||||
def chat_completion_stream():
|
||||
messages = ["Hello ", "World ", " !"]
|
||||
for i, message in enumerate(messages):
|
||||
chunk = self._dummy_chat_completion_chunk()
|
||||
chunk.choices[0].delta.content = message
|
||||
if i == len(messages) - 1:
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
else:
|
||||
chunk.choices[0].finish_reason = None
|
||||
yield chunk
|
||||
|
||||
chunk = self._dummy_chat_completion_chunk()
|
||||
chunk.choices[0].delta.content = None
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
yield chunk
|
||||
|
||||
stream = chat_completion_stream()
|
||||
converted = convert_chat_completion_response_stream(stream)
|
||||
|
||||
iter = converted.__aiter__()
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.start
|
||||
assert chunk.event.delta == "Hello "
|
||||
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == "World "
|
||||
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == " !"
|
||||
|
||||
# Dummy chunk to ensure the last chunk is really the end of the stream
|
||||
# This one technically maps to Groq's final "stop" chunk
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == ""
|
||||
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunk.event.delta == ""
|
||||
assert chunk.event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await iter.__anext__()
|
||||
|
||||
def _dummy_chat_completion_chunk(self):
|
||||
return ChatCompletionChunk(
|
||||
id="chatcmpl-123",
|
||||
model="Llama-3.2-3B",
|
||||
choices=[
|
||||
StreamChoice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(role="assistant", content="Hello World"),
|
||||
)
|
||||
],
|
||||
created=1729382400,
|
||||
object="chat.completion.chunk",
|
||||
x_groq=None,
|
||||
)
|
||||
29
llama_stack/providers/tests/inference/groq/test_init.py
Normal file
29
llama_stack/providers/tests/inference/groq/test_init.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.providers.remote.inference.groq import get_adapter_impl
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
|
||||
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
|
||||
|
||||
class TestGroqInit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_runtime_error_if_config_is_not_groq_config(self):
|
||||
config = OllamaImplConfig(model="llama3.1-8b-8192")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await get_adapter_impl(config, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_groq_adapter(self):
|
||||
config = GroqConfig()
|
||||
adapter = await get_adapter_impl(config, None)
|
||||
assert type(adapter) is GroqInferenceAdapter
|
||||
assert isinstance(adapter, Inference)
|
||||
|
|
@ -350,6 +350,14 @@ class TestInference:
|
|||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
||||
pytest.skip(
|
||||
provider.__provider_spec__.provider_type
|
||||
+ " doesn't support tool calling yet"
|
||||
)
|
||||
|
||||
inference_impl, _ = inference_stack
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
|
|
@ -390,6 +398,13 @@ class TestInference:
|
|||
sample_tool_definition,
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
||||
pytest.skip(
|
||||
provider.__provider_spec__.provider_type
|
||||
+ " doesn't support tool calling yet"
|
||||
)
|
||||
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue