Add Groq provider - chat completions

Aidan Do 2024-12-12 21:15:09 +11:00
parent c294a01c4b
commit 378150e23c
10 changed files with 727 additions and 31 deletions


@@ -22,10 +22,10 @@ Our goal is to provide pre-packaged implementations which can be operated in a v
> ⚠️ **Note**
> The Stack APIs are rapidly improving, but still very much work in progress and we invite feedback as well as direct contributions.

## APIs

We have working implementations of the following APIs today:
- Inference
- Safety
- Memory
@@ -74,19 +74,21 @@ There is a vibrant ecosystem of Providers which provide efficient inference or s
Additionally, we have designed every element of the Stack such that APIs as well as Resources (like Models) can be federated.

## Supported Llama Stack Implementations

### API Providers

| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----------------------------------------------------------------------------------------: | :--------------------: | :----------------: | :----------------: | :----------------: | :----------------: | :----------------: |
| Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Cerebras | Hosted | | :heavy_check_mark: | | | |
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
| Groq | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Ollama | Single Node | | :heavy_check_mark: | | | |
| TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | |
| Chroma | Single Node | | | :heavy_check_mark: | | |
| PG Vector | Single Node | | | :heavy_check_mark: | | |
| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | |
@@ -95,7 +97,7 @@ Additionally, we have designed every element of the Stack such that APIs as well

### Distributions

| **Distribution** | **Llama Stack Docker** | Start This Distribution |
| :-------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------: |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
@@ -111,6 +113,7 @@ You have two ways to install this repository:

1. **Install as a package**:
   You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:

   ```bash
   pip install llama-stack
   ```

@@ -118,6 +121,7 @@ You have two ways to install this repository:

2. **Install from source**:
   If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
   Then, follow these steps:

   ```bash
   mkdir -p ~/local
   cd ~/local
@@ -134,24 +138,24 @@ You have two ways to install this repository:

Please check out our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details.

- [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
  - Guide to using the `llama` CLI to work with Llama models (download, study prompts) and to build/start a Llama Stack distribution.
- [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
  - Quick guide to start a Llama Stack server.
  - [Jupyter notebook](./docs/getting_started.ipynb) that walks through how to use the simple text and vision inference `llama_stack_client` APIs.
  - The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) from the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
  - A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guides you through all the key components of Llama Stack, with code samples.
- [Contributing](CONTRIBUTING.md)
  - [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/contributing/new_api_provider.html), a walkthrough of how to add a new API provider.

## Llama Stack Client SDKs

| **Language** | **Client SDK** | **Package** |
| :----------: | :----------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/) |
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift) |
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client) |
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin) |

Check out our client SDKs for connecting to a Llama Stack server in your preferred language: you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) to quickly build your applications.
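For orientation (editor's illustration, not part of this commit): a minimal sketch of calling a running Llama Stack server with the Python SDK. It assumes a server listening on localhost port 5000 and a registered Llama 3.1 model; exact method names and model IDs may differ between SDK versions.

```python
# a minimal sketch, assuming `pip install llama-stack-client` and a local server
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")
response = client.inference.chat_completion(
    model_id="Llama3.1-8B-Instruct",  # hypothetical registered model ID
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.completion_message.content)
```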


@@ -149,6 +149,15 @@ def available_providers() -> List[ProviderSpec]:
            provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="groq",
            pip_packages=["groq"],
            module="llama_stack.providers.remote.inference.groq",
            config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(


@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import GroqConfig


async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
    # import dynamically so `llama stack build` does not fail due to missing dependencies
    from .groq import GroqInferenceAdapter

    if not isinstance(config, GroqConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")

    adapter = GroqInferenceAdapter(config)
    return adapter
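This factory is what the registry entry above points to: `llama stack build` installs the `groq` pip package, and at runtime the stack imports the module named in the spec and calls `get_adapter_impl` with the parsed `config_class` instance. A rough, hypothetical sketch of that resolution step (simplified from the real resolver):

```python
# a simplified, hypothetical sketch of how the stack resolves a remote adapter
import importlib

from llama_stack.providers.remote.inference.groq.config import GroqConfig

async def resolve_groq_adapter():
    module = importlib.import_module("llama_stack.providers.remote.inference.groq")
    return await module.get_adapter_impl(GroqConfig(), None)  # second arg: deps
```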


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class GroqConfig(BaseModel):
    api_key: Optional[str] = Field(
        # The Groq client library loads the GROQ_API_KEY environment variable by default
        default=None,
        description="The Groq API key",
    )
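A small sketch (editor's illustration, not part of the commit) of the two ways a key can be supplied under this config. The fallback path relies on the Groq client reading `GROQ_API_KEY` itself, as the comment above notes:

```python
# a sketch: supplying the Groq API key explicitly vs. via the environment
import os

from llama_stack.providers.remote.inference.groq.config import GroqConfig

explicit = GroqConfig(api_key=os.environ.get("GROQ_API_KEY"))
implicit = GroqConfig()  # api_key=None; Groq(api_key=None) falls back to GROQ_API_KEY
```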


@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
    InterleavedTextMedia,
    Message,
    ToolChoice,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_models.sku_list import CoreModelId

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    ResponseFormat,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    build_model_alias_with_just_provider_model_id,
    ModelRegistryHelper,
)

from .groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)

_MODEL_ALIASES = [
    build_model_alias(
        "llama3-8b-8192",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "llama-3.3-70b-versatile",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # Groq only contains a preview version for llama-3.2-3b
    # Preview models aren't recommended for production use, but we include this one
    # to pass the test fixture
    # TODO(aidand): Replace this with a stable model once Groq supports it
    build_model_alias(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
]


class GroqInferenceAdapter(Inference, ModelRegistryHelper):
    _client: Groq

    def __init__(self, config: GroqConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._client = Groq(api_key=config.api_key)

    def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # Groq doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[
            ToolPromptFormat
        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        if model_id == "llama-3.2-3b-preview":
            warnings.warn(
                "Groq only contains a preview version for llama-3.2-3b-instruct. "
                "Preview models aren't recommended for production use. "
                "They can be discontinued on short notice."
            )

        model_id = self.get_provider_model_id(model_id)
        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                tool_choice=tool_choice,
                tool_prompt_format=tool_prompt_format,
                stream=stream,
                logprobs=logprobs,
            )
        )

        response = self._client.chat.completions.create(**request)

        if stream:
            return convert_chat_completion_response_stream(response)
        else:
            return convert_chat_completion_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
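To make the flow above concrete, a hedged sketch (editor's illustration, not part of the commit) of driving the adapter directly, outside the stack. It assumes `GROQ_API_KEY` is set and that alias resolution accepts the core model descriptor used in `_MODEL_ALIASES`:

```python
# a sketch: exercising GroqInferenceAdapter directly
import asyncio

from llama_models.sku_list import CoreModelId
from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter

async def main():
    adapter = GroqInferenceAdapter(GroqConfig())  # key comes from GROQ_API_KEY
    response = await adapter.chat_completion(
        model_id=CoreModelId.llama3_1_8b_instruct.value,  # resolves to "llama3-8b-8192"
        messages=[UserMessage(content="Say hello in five words.")],
    )
    print(response.completion_message.content)

asyncio.run(main())
```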


@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncGenerator, Generator, Literal

from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
    ChatCompletionAssistantMessageParam,
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_system_message_param import (
    ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_user_message_param import (
    ChatCompletionUserMessageParam,
)
from groq.types.chat.completion_create_params import CompletionCreateParams

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    Message,
    Role,
    StopReason,
)


def convert_chat_completion_request(
    request: ChatCompletionRequest,
) -> CompletionCreateParams:
    """
    Convert a ChatCompletionRequest to a Groq API-compatible dictionary.

    Warns client if request contains unsupported features.
    """
    if request.logprobs:
        # Groq doesn't support logprobs at the time of writing
        warnings.warn("logprobs are not supported yet")

    if request.response_format:
        # Groq's JSON mode is beta at the time of writing
        warnings.warn("response_format is not supported yet")

    if request.sampling_params.repetition_penalty:
        # groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
        # seem to have different semantics
        # frequency_penalty defaults to 0 and is a float between -2.0 and 2.0
        # repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
        # so we exclude it for now
        warnings.warn("repetition_penalty is not supported")

    if request.tools:
        warnings.warn("tools are not supported yet")

    return CompletionCreateParams(
        model=request.model,
        messages=[_convert_message(message) for message in request.messages],
        logprobs=None,
        frequency_penalty=None,
        stream=request.stream,
        # groq only supports n=1 at time of writing
        # (this key is asserted by test_n_is_1 below; it appears to have been
        # dropped in extraction, so it is restored here)
        n=1,
        max_tokens=request.sampling_params.max_tokens or None,
        temperature=request.sampling_params.temperature,
        top_p=request.sampling_params.top_p,
    )

def _convert_message(message: Message) -> ChatCompletionMessageParam:
    if message.role == Role.system.value:
        return ChatCompletionSystemMessageParam(role="system", content=message.content)
    elif message.role == Role.user.value:
        return ChatCompletionUserMessageParam(role="user", content=message.content)
    elif message.role == Role.assistant.value:
        return ChatCompletionAssistantMessageParam(
            role="assistant", content=message.content
        )
    else:
        raise ValueError(f"Invalid message role: {message.role}")


def convert_chat_completion_response(
    response: ChatCompletion,
) -> ChatCompletionResponse:
    # groq only supports n=1 at time of writing, so there is only one choice
    choice = response.choices[0]
    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            content=choice.message.content,
            stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
        ),
    )


def _map_finish_reason_to_stop_reason(
    finish_reason: Literal["stop", "length", "tool_calls"]
) -> StopReason:
    """
    Convert a Groq chat completion finish_reason to a StopReason.

    finish_reason: Literal["stop", "length", "tool_calls"]
        - stop -> model hit a natural stop point or a provided stop sequence
        - length -> maximum number of tokens specified in the request was reached
        - tool_calls -> model called a tool
    """
    if finish_reason == "stop":
        return StopReason.end_of_turn
    elif finish_reason == "length":
        return StopReason.end_of_message
    elif finish_reason == "tool_calls":
        raise NotImplementedError("tool_calls is not supported yet")
    else:
        raise ValueError(f"Invalid finish reason: {finish_reason}")


async def convert_chat_completion_response_stream(
    stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
    def _event_type_generator() -> (
        Generator[ChatCompletionResponseEventType, None, None]
    ):
        yield ChatCompletionResponseEventType.start
        while True:
            yield ChatCompletionResponseEventType.progress

    event_types = _event_type_generator()

    for chunk in stream:
        choice = chunk.choices[0]

        # We assume there's only one finish_reason for the entire stream.
        # We collect the last finish_reason
        if choice.finish_reason:
            stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)

        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=next(event_types),
                delta=choice.delta.content or "",
                logprobs=None,
            )
        )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            logprobs=None,
            stop_reason=stop_reason,
        )
    )
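As a quick illustration of the request conversion (editor's sketch; the unit tests below are the authoritative reference), a bare request converts to a plain dict ready to splat into `client.chat.completions.create`:

```python
# a sketch mirroring the unit tests below
from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.remote.inference.groq.groq_utils import (
    convert_chat_completion_request,
)

request = ChatCompletionRequest(
    model="llama3-8b-8192",
    messages=[UserMessage(content="Hello World")],
)
params = convert_chat_completion_request(request)
assert params["model"] == "llama3-8b-8192"
assert params["messages"] == [{"role": "user", "content": "Hello World"}]
assert params["n"] == 1  # Groq supports only one choice per request
```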


@@ -19,6 +19,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
@@ -150,6 +151,22 @@ def inference_together() -> ProviderFixture:
    )


@pytest.fixture(scope="session")
def inference_groq() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="groq",
                provider_type="remote::groq",
                config=GroqConfig().model_dump(),
            )
        ],
        provider_data=dict(
            groq_api_key=get_env_or_fail("GROQ_API_KEY"),
        ),
    )


@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
    return ProviderFixture(
@@ -222,6 +239,7 @@ INFERENCE_FIXTURES = [
    "ollama",
    "fireworks",
    "together",
    "groq",
    "vllm_remote",
    "remote",
    "bedrock",


@@ -0,0 +1,278 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
    ChatCompletionChunk,
    Choice as StreamChoice,
    ChoiceDelta,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponseEventType,
    CompletionMessage,
    StopReason,
    SystemMessage,
    UserMessage,
)
from llama_stack.providers.remote.inference.groq.groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)


class TestConvertChatCompletionRequest:
    def test_sets_model(self):
        request = self._dummy_chat_completion_request()
        request.model = "Llama-3.2-3B"

        converted = convert_chat_completion_request(request)

        assert converted["model"] == "Llama-3.2-3B"

    def test_converts_user_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [UserMessage(content="Hello World")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
        ]

    def test_converts_system_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [SystemMessage(content="You are a helpful assistant.")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "system", "content": "You are a helpful assistant."},
        ]

    def test_converts_completion_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [
            UserMessage(content="Hello World"),
            CompletionMessage(
                content="Hello World! How can I help you today?",
                stop_reason=StopReason.end_of_message,
            ),
        ]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
            {"role": "assistant", "content": "Hello World! How can I help you today?"},
        ]

    def test_does_not_include_logprobs(self):
        request = self._dummy_chat_completion_request()
        request.logprobs = True

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "logprobs are not supported yet" in warnings[0].message.args[0]
        assert converted.get("logprobs") is None

    def test_does_not_include_response_format(self):
        request = self._dummy_chat_completion_request()
        request.response_format = {
            "type": "json_object",
            "json_schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "number"},
                },
            },
        }

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "response_format is not supported yet" in warnings[0].message.args[0]
        assert converted.get("response_format") is None

    def test_does_not_include_repetition_penalty(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.repetition_penalty = 1.5

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "repetition_penalty is not supported" in warnings[0].message.args[0]
        assert converted.get("repetition_penalty") is None
        assert converted.get("frequency_penalty") is None

    def test_includes_stream(self):
        request = self._dummy_chat_completion_request()
        request.stream = True

        converted = convert_chat_completion_request(request)

        assert converted["stream"] is True

    def test_n_is_1(self):
        request = self._dummy_chat_completion_request()

        converted = convert_chat_completion_request(request)

        assert converted["n"] == 1

    def test_if_max_tokens_is_0_then_it_is_not_included(self):
        request = self._dummy_chat_completion_request()
        # 0 is the default value for max_tokens
        # So we assume that if it's 0, the user didn't set it
        request.sampling_params.max_tokens = 0

        converted = convert_chat_completion_request(request)

        assert converted.get("max_tokens") is None

    def test_includes_max_tokens_if_set(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.max_tokens = 100

        converted = convert_chat_completion_request(request)

        assert converted["max_tokens"] == 100

    def _dummy_chat_completion_request(self):
        return ChatCompletionRequest(
            model="Llama-3.2-3B",
            messages=[UserMessage(content="Hello World")],
        )

    def test_includes_temperature(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.temperature = 0.5

        converted = convert_chat_completion_request(request)

        assert converted["temperature"] == 0.5

    def test_includes_top_p(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.top_p = 0.95

        converted = convert_chat_completion_request(request)

        assert converted["top_p"] == 0.95

class TestConvertNonStreamChatCompletionResponse:
    def test_returns_response(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].message.content = "Hello World"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.content == "Hello World"

    # renamed from test_maps_stop_to_end_of_message: the assertion checks end_of_turn
    def test_maps_stop_to_end_of_turn(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "stop"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_turn

    def test_maps_length_to_end_of_message(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "length"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_message

    def _dummy_chat_completion_response(self):
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant", content="Hello World"
                    ),
                    finish_reason="stop",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )


class TestConvertStreamChatCompletionResponse:
    @pytest.mark.asyncio
    async def test_returns_stream(self):
        def chat_completion_stream():
            messages = ["Hello ", "World ", " !"]
            for i, message in enumerate(messages):
                chunk = self._dummy_chat_completion_chunk()
                chunk.choices[0].delta.content = message
                if i == len(messages) - 1:
                    chunk.choices[0].finish_reason = "stop"
                else:
                    chunk.choices[0].finish_reason = None
                yield chunk

            chunk = self._dummy_chat_completion_chunk()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = chat_completion_stream()
        converted = convert_chat_completion_response_stream(stream)

        iter = converted.__aiter__()
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta == "Hello "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == "World "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == " !"

        # Dummy chunk to ensure the last chunk is really the end of the stream
        # This one technically maps to Groq's final "stop" chunk
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == ""

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
        assert chunk.event.delta == ""
        assert chunk.event.stop_reason == StopReason.end_of_turn

        with pytest.raises(StopAsyncIteration):
            await iter.__anext__()

    def _dummy_chat_completion_chunk(self):
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(role="assistant", content="Hello World"),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )


@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.inference import Inference
from llama_stack.providers.remote.inference.groq import get_adapter_impl
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig


class TestGroqInit:
    @pytest.mark.asyncio
    async def test_raises_runtime_error_if_config_is_not_groq_config(self):
        config = OllamaImplConfig(model="llama3.1-8b-8192")

        with pytest.raises(RuntimeError):
            await get_adapter_impl(config, None)

    @pytest.mark.asyncio
    async def test_returns_groq_adapter(self):
        config = GroqConfig()

        adapter = await get_adapter_impl(config, None)

        assert type(adapter) is GroqInferenceAdapter
        assert isinstance(adapter, Inference)


@@ -350,6 +350,14 @@ class TestInference:
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type in ("remote::groq",):
            pytest.skip(
                provider.__provider_spec__.provider_type
                + " doesn't support tool calling yet"
            )

        inference_impl, _ = inference_stack
        messages = sample_messages + [
            UserMessage(
@@ -390,6 +398,13 @@ class TestInference:
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type in ("remote::groq",):
            pytest.skip(
                provider.__provider_spec__.provider_type
                + " doesn't support tool calling yet"
            )

        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",