mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
Add Groq provider - chat completions
This commit is contained in:
parent
c294a01c4b
commit
378150e23c
10 changed files with 727 additions and 31 deletions
66
README.md
66
README.md
|
@ -22,10 +22,10 @@ Our goal is to provide pre-packaged implementations which can be operated in a v
|
|||
> ⚠️ **Note**
|
||||
> The Stack APIs are rapidly improving, but still very much work in progress and we invite feedback as well as direct contributions.
|
||||
|
||||
|
||||
## APIs
|
||||
|
||||
We have working implementations of the following APIs today:
|
||||
|
||||
- Inference
|
||||
- Safety
|
||||
- Memory
|
||||
|
@ -74,19 +74,21 @@ There is a vibrant ecosystem of Providers which provide efficient inference or s
|
|||
|
||||
Additionally, we have designed every element of the Stack such that APIs as well as Resources (like Models) can be federated.
|
||||
|
||||
|
||||
## Supported Llama Stack Implementations
|
||||
|
||||
### API Providers
|
||||
|
||||
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|
||||
|:------------------------------------------------------------------------------------------:|:----------------------:|:------------------:|:------------------:|:------------------:|:------------------:|:------------------:|
|
||||
| :----------------------------------------------------------------------------------------: | :--------------------: | :----------------: | :----------------: | :----------------: | :----------------: | :----------------: |
|
||||
| Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Cerebras | Hosted | | :heavy_check_mark: | | | |
|
||||
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
|
||||
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
|
||||
| Ollama | Single Node | | :heavy_check_mark: | | | |
|
||||
| TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
|
||||
| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | |
|
||||
| Groq | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
|
||||
| Ollama | Single Node | | :heavy_check_mark: | | |
|
||||
| TGI | Hosted and Single Node | | :heavy_check_mark: | | |
|
||||
| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | |
|
||||
| Chroma | Single Node | | | :heavy_check_mark: | | |
|
||||
| PG Vector | Single Node | | | :heavy_check_mark: | | |
|
||||
| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | |
|
||||
|
@ -94,16 +96,16 @@ Additionally, we have designed every element of the Stack such that APIs as well
|
|||
|
||||
### Distributions
|
||||
|
||||
| **Distribution** | **Llama Stack Docker** | Start This Distribution |
|
||||
|:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
|
||||
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
|
||||
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
|
||||
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
|
||||
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
|
||||
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
|
||||
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |
|
||||
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
|
||||
| [vLLM](https://github.com/vllm-project/vllm) | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |
|
||||
| **Distribution** | **Llama Stack Docker** | Start This Distribution |
|
||||
| :------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------: |
|
||||
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
|
||||
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
|
||||
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
|
||||
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
|
||||
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
|
||||
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |
|
||||
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
|
||||
| [vLLM](https://github.com/vllm-project/vllm) | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -111,6 +113,7 @@ You have two ways to install this repository:
|
|||
|
||||
1. **Install as a package**:
|
||||
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
|
||||
|
||||
```bash
|
||||
pip install llama-stack
|
||||
```
|
||||
|
@ -118,6 +121,7 @@ You have two ways to install this repository:
|
|||
2. **Install from source**:
|
||||
If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
|
||||
Then, follow these steps:
|
||||
|
||||
```bash
|
||||
mkdir -p ~/local
|
||||
cd ~/local
|
||||
|
@ -134,24 +138,24 @@ You have two ways to install this repository:
|
|||
|
||||
Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details.
|
||||
|
||||
* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
|
||||
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
|
||||
* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
|
||||
* Quick guide to start a Llama Stack server.
|
||||
* [Jupyter notebook](./docs/getting_started.ipynb) to walk-through how to use simple text and vision inference llama_stack_client APIs
|
||||
* The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
|
||||
* A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples.
|
||||
* [Contributing](CONTRIBUTING.md)
|
||||
* [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/contributing/new_api_provider.html) to walk-through how to add a new API provider.
|
||||
- [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
|
||||
- Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
|
||||
- [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
|
||||
- Quick guide to start a Llama Stack server.
|
||||
- [Jupyter notebook](./docs/getting_started.ipynb) to walk-through how to use simple text and vision inference llama_stack_client APIs
|
||||
- The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
|
||||
- A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples.
|
||||
- [Contributing](CONTRIBUTING.md)
|
||||
- [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/contributing/new_api_provider.html) to walk-through how to add a new API provider.
|
||||
|
||||
## Llama Stack Client SDKs
|
||||
|
||||
| **Language** | **Client SDK** | **Package** |
|
||||
| :----: | :----: | :----: |
|
||||
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [](https://pypi.org/project/llama_stack_client/)
|
||||
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
|
||||
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [](https://npmjs.org/package/llama-stack-client)
|
||||
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
|
||||
| **Language** | **Client SDK** | **Package** |
|
||||
| :----------: | :----------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [](https://pypi.org/project/llama_stack_client/) |
|
||||
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift) |
|
||||
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [](https://npmjs.org/package/llama-stack-client) |
|
||||
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin) |
|
||||
|
||||
Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
|
||||
|
||||
|
|
|
@ -149,6 +149,15 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq",
|
||||
pip_packages=["groq"],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
20
llama_stack/providers/remote/inference/groq/__init__.py
Normal file
20
llama_stack/providers/remote/inference/groq/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import GroqConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
if not isinstance(config, GroqConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
|
||||
adapter = GroqInferenceAdapter(config)
|
||||
return adapter
|
19
llama_stack/providers/remote/inference/groq/config.py
Normal file
19
llama_stack/providers/remote/inference/groq/config.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
142
llama_stack/providers/remote/inference/groq/groq.py
Normal file
142
llama_stack/providers/remote/inference/groq/groq.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
from groq import Groq
|
||||
from llama_models.datatypes import SamplingParams
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
InterleavedTextMedia,
|
||||
Message,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.sku_list import CoreModelId
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
ResponseFormat,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
build_model_alias_with_just_provider_model_id,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from .groq_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_chat_completion_response,
|
||||
convert_chat_completion_response_stream,
|
||||
)
|
||||
|
||||
_MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_alias_with_just_provider_model_id(
|
||||
"llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
# Preview models aren't recommended for production use, but we include this one
|
||||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_model_alias(
|
||||
"llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GroqInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
_client: Groq
|
||||
|
||||
def __init__(self, config: GroqConfig):
|
||||
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
|
||||
self._client = Groq(api_key=config.api_key)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
# Groq doesn't support non-chat completion as of time of writing
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[
|
||||
ToolPromptFormat
|
||||
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
|
||||
if model_id == "llama-3.2-3b-preview":
|
||||
warnings.warn(
|
||||
"Groq only contains a preview version for llama-3.2-3b-instruct. "
|
||||
"Preview models aren't recommended for production use. "
|
||||
"They can be discontinued on short notice."
|
||||
)
|
||||
|
||||
model_id = self.get_provider_model_id(model_id)
|
||||
request = convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
response = self._client.chat.completions.create(**request)
|
||||
|
||||
if stream:
|
||||
return convert_chat_completion_response_stream(response)
|
||||
else:
|
||||
return convert_chat_completion_response(response)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
162
llama_stack/providers/remote/inference/groq/groq_utils.py
Normal file
162
llama_stack/providers/remote/inference/groq/groq_utils.py
Normal file
|
@ -0,0 +1,162 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Generator, Literal
|
||||
|
||||
from groq import Stream
|
||||
from groq.types.chat.chat_completion import ChatCompletion
|
||||
from groq.types.chat.chat_completion_assistant_message_param import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
)
|
||||
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
||||
from groq.types.chat.chat_completion_system_message_param import (
|
||||
ChatCompletionSystemMessageParam,
|
||||
)
|
||||
from groq.types.chat.chat_completion_user_message_param import (
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from groq.types.chat.completion_create_params import CompletionCreateParams
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Message,
|
||||
Role,
|
||||
StopReason,
|
||||
)
|
||||
|
||||
|
||||
def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
) -> CompletionCreateParams:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
|
||||
Warns client if request contains unsupported features.
|
||||
"""
|
||||
|
||||
if request.logprobs:
|
||||
# Groq doesn't support logprobs at the time of writing
|
||||
warnings.warn("logprobs are not supported yet")
|
||||
|
||||
if request.response_format:
|
||||
# Groq's JSON mode is beta at the time of writing
|
||||
warnings.warn("response_format is not supported yet")
|
||||
|
||||
if request.sampling_params.repetition_penalty:
|
||||
# groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
|
||||
# seem to have different semantics
|
||||
# frequency_penalty defaults to 0 is a float between -2.0 and 2.0
|
||||
# repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
|
||||
# so we exclude it for now
|
||||
warnings.warn("repetition_penalty is not supported")
|
||||
|
||||
if request.tools:
|
||||
warnings.warn("tools are not supported yet")
|
||||
|
||||
return CompletionCreateParams(
|
||||
model=request.model,
|
||||
messages=[_convert_message(message) for message in request.messages],
|
||||
logprobs=None,
|
||||
frequency_penalty=None,
|
||||
stream=request.stream,
|
||||
max_tokens=request.sampling_params.max_tokens or None,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
)
|
||||
|
||||
|
||||
def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
||||
if message.role == Role.system.value:
|
||||
return ChatCompletionSystemMessageParam(role="system", content=message.content)
|
||||
elif message.role == Role.user.value:
|
||||
return ChatCompletionUserMessageParam(role="user", content=message.content)
|
||||
elif message.role == Role.assistant.value:
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
role="assistant", content=message.content
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
|
||||
def convert_chat_completion_response(
|
||||
response: ChatCompletion,
|
||||
) -> ChatCompletionResponse:
|
||||
# groq only supports n=1 at time of writing, so there is only one choice
|
||||
choice = response.choices[0]
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _map_finish_reason_to_stop_reason(
|
||||
finish_reason: Literal["stop", "length", "tool_calls"]
|
||||
) -> StopReason:
|
||||
"""
|
||||
Convert a Groq chat completion finish_reason to a StopReason.
|
||||
|
||||
finish_reason: Literal["stop", "length", "tool_calls"]
|
||||
- stop -> model hit a natural stop point or a provided stop sequence
|
||||
- length -> maximum number of tokens specified in the request was reached
|
||||
- tool_calls -> model called a tool
|
||||
"""
|
||||
if finish_reason == "stop":
|
||||
return StopReason.end_of_turn
|
||||
elif finish_reason == "length":
|
||||
return StopReason.end_of_message
|
||||
elif finish_reason == "tool_calls":
|
||||
raise NotImplementedError("tool_calls is not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Invalid finish reason: {finish_reason}")
|
||||
|
||||
|
||||
async def convert_chat_completion_response_stream(
|
||||
stream: Stream[ChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
|
||||
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
|
||||
def _event_type_generator() -> (
|
||||
Generator[ChatCompletionResponseEventType, None, None]
|
||||
):
|
||||
yield ChatCompletionResponseEventType.start
|
||||
while True:
|
||||
yield ChatCompletionResponseEventType.progress
|
||||
|
||||
event_types = _event_type_generator()
|
||||
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
|
||||
# We assume there's only one finish_reason for the entire stream.
|
||||
# We collect the last finish_reason
|
||||
if choice.finish_reason:
|
||||
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_types),
|
||||
delta=choice.delta.content or "",
|
||||
logprobs=None,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
logprobs=None,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
|
@ -19,6 +19,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
|||
|
||||
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.groq import GroqConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
||||
|
@ -150,6 +151,22 @@ def inference_together() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_groq() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="groq",
|
||||
provider_type="remote::groq",
|
||||
config=GroqConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
provider_data=dict(
|
||||
groq_api_key=get_env_or_fail("GROQ_API_KEY"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_bedrock() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
|
@ -222,6 +239,7 @@ INFERENCE_FIXTURES = [
|
|||
"ollama",
|
||||
"fireworks",
|
||||
"together",
|
||||
"groq",
|
||||
"vllm_remote",
|
||||
"remote",
|
||||
"bedrock",
|
||||
|
|
278
llama_stack/providers/tests/inference/groq/test_groq_utils.py
Normal file
278
llama_stack/providers/tests/inference/groq/test_groq_utils.py
Normal file
|
@ -0,0 +1,278 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from groq.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from groq.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk,
|
||||
Choice as StreamChoice,
|
||||
ChoiceDelta,
|
||||
)
|
||||
from groq.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.groq_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_chat_completion_response,
|
||||
convert_chat_completion_response_stream,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertChatCompletionRequest:
|
||||
def test_sets_model(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.model = "Llama-3.2-3B"
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["model"] == "Llama-3.2-3B"
|
||||
|
||||
def test_converts_user_message(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.messages = [UserMessage(content="Hello World")]
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["messages"] == [
|
||||
{"role": "user", "content": "Hello World"},
|
||||
]
|
||||
|
||||
def test_converts_system_message(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.messages = [SystemMessage(content="You are a helpful assistant.")]
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["messages"] == [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
def test_converts_completion_message(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.messages = [
|
||||
UserMessage(content="Hello World"),
|
||||
CompletionMessage(
|
||||
content="Hello World! How can I help you today?",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
),
|
||||
]
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["messages"] == [
|
||||
{"role": "user", "content": "Hello World"},
|
||||
{"role": "assistant", "content": "Hello World! How can I help you today?"},
|
||||
]
|
||||
|
||||
def test_does_not_include_logprobs(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.logprobs = True
|
||||
|
||||
with pytest.warns(Warning) as warnings:
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert "logprobs are not supported yet" in warnings[0].message.args[0]
|
||||
assert converted.get("logprobs") is None
|
||||
|
||||
def test_does_not_include_response_format(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.response_format = {
|
||||
"type": "json_object",
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "number"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.warns(Warning) as warnings:
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert "response_format is not supported yet" in warnings[0].message.args[0]
|
||||
assert converted.get("response_format") is None
|
||||
|
||||
def test_does_not_include_repetition_penalty(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.repetition_penalty = 1.5
|
||||
|
||||
with pytest.warns(Warning) as warnings:
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert "repetition_penalty is not supported" in warnings[0].message.args[0]
|
||||
assert converted.get("repetition_penalty") is None
|
||||
assert converted.get("frequency_penalty") is None
|
||||
|
||||
def test_includes_stream(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.stream = True
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["stream"] is True
|
||||
|
||||
def test_n_is_1(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["n"] == 1
|
||||
|
||||
def test_if_max_tokens_is_0_then_it_is_not_included(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
# 0 is the default value for max_tokens
|
||||
# So we assume that if it's 0, the user didn't set it
|
||||
request.sampling_params.max_tokens = 0
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted.get("max_tokens") is None
|
||||
|
||||
def test_includes_max_tokens_if_set(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.max_tokens = 100
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["max_tokens"] == 100
|
||||
|
||||
def _dummy_chat_completion_request(self):
|
||||
return ChatCompletionRequest(
|
||||
model="Llama-3.2-3B",
|
||||
messages=[UserMessage(content="Hello World")],
|
||||
)
|
||||
|
||||
def test_includes_temperature(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.temperature = 0.5
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["temperature"] == 0.5
|
||||
|
||||
def test_includes_top_p(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.top_p = 0.95
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["top_p"] == 0.95
|
||||
|
||||
|
||||
class TestConvertNonStreamChatCompletionResponse:
|
||||
def test_returns_response(self):
|
||||
response = self._dummy_chat_completion_response()
|
||||
response.choices[0].message.content = "Hello World"
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.content == "Hello World"
|
||||
|
||||
def test_maps_stop_to_end_of_message(self):
|
||||
response = self._dummy_chat_completion_response()
|
||||
response.choices[0].finish_reason = "stop"
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.stop_reason == StopReason.end_of_turn
|
||||
|
||||
def test_maps_length_to_end_of_message(self):
|
||||
response = self._dummy_chat_completion_response()
|
||||
response.choices[0].finish_reason = "length"
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.stop_reason == StopReason.end_of_message
|
||||
|
||||
def _dummy_chat_completion_response(self):
|
||||
return ChatCompletion(
|
||||
id="chatcmpl-123",
|
||||
model="Llama-3.2-3B",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content="Hello World"
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
created=1729382400,
|
||||
object="chat.completion",
|
||||
)
|
||||
|
||||
|
||||
class TestConvertStreamChatCompletionResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_stream(self):
|
||||
def chat_completion_stream():
|
||||
messages = ["Hello ", "World ", " !"]
|
||||
for i, message in enumerate(messages):
|
||||
chunk = self._dummy_chat_completion_chunk()
|
||||
chunk.choices[0].delta.content = message
|
||||
if i == len(messages) - 1:
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
else:
|
||||
chunk.choices[0].finish_reason = None
|
||||
yield chunk
|
||||
|
||||
chunk = self._dummy_chat_completion_chunk()
|
||||
chunk.choices[0].delta.content = None
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
yield chunk
|
||||
|
||||
stream = chat_completion_stream()
|
||||
converted = convert_chat_completion_response_stream(stream)
|
||||
|
||||
iter = converted.__aiter__()
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.start
|
||||
assert chunk.event.delta == "Hello "
|
||||
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == "World "
|
||||
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == " !"
|
||||
|
||||
# Dummy chunk to ensure the last chunk is really the end of the stream
|
||||
# This one technically maps to Groq's final "stop" chunk
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == ""
|
||||
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunk.event.delta == ""
|
||||
assert chunk.event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await iter.__anext__()
|
||||
|
||||
def _dummy_chat_completion_chunk(self):
|
||||
return ChatCompletionChunk(
|
||||
id="chatcmpl-123",
|
||||
model="Llama-3.2-3B",
|
||||
choices=[
|
||||
StreamChoice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(role="assistant", content="Hello World"),
|
||||
)
|
||||
],
|
||||
created=1729382400,
|
||||
object="chat.completion.chunk",
|
||||
x_groq=None,
|
||||
)
|
29
llama_stack/providers/tests/inference/groq/test_init.py
Normal file
29
llama_stack/providers/tests/inference/groq/test_init.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.providers.remote.inference.groq import get_adapter_impl
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
|
||||
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
|
||||
|
||||
class TestGroqInit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_runtime_error_if_config_is_not_groq_config(self):
|
||||
config = OllamaImplConfig(model="llama3.1-8b-8192")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await get_adapter_impl(config, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_groq_adapter(self):
|
||||
config = GroqConfig()
|
||||
adapter = await get_adapter_impl(config, None)
|
||||
assert type(adapter) is GroqInferenceAdapter
|
||||
assert isinstance(adapter, Inference)
|
|
@ -350,6 +350,14 @@ class TestInference:
|
|||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
||||
pytest.skip(
|
||||
provider.__provider_spec__.provider_type
|
||||
+ " doesn't support tool calling yet"
|
||||
)
|
||||
|
||||
inference_impl, _ = inference_stack
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
|
@ -390,6 +398,13 @@ class TestInference:
|
|||
sample_tool_definition,
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
||||
pytest.skip(
|
||||
provider.__provider_spec__.provider_type
|
||||
+ " doesn't support tool calling yet"
|
||||
)
|
||||
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue