# What does this PR do?
Contributes towards issue #432. Adds:
- Groq text chat completions
- Streaming
- All the sampling params that Groq supports
A lot of inspiration was taken from @mattf's good work in
https://github.com/meta-llama/llama-stack/pull/355
**What this PR does not do**
- Tool calls (future PR)
- Adding the llama-guard model
- Investigating whether we can add embeddings
### PR Train
- https://github.com/meta-llama/llama-stack/pull/609 👈
- https://github.com/meta-llama/llama-stack/pull/630
## Test Plan
<details>
<summary>Environment</summary>
```bash
export GROQ_API_KEY=<api_key>
wget https://raw.githubusercontent.com/aidando73/llama-stack/240e6e2a9c20450ffdcfbabd800a6c0291f19288/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/92c9b5297f9eda6a6e901e1adbd894e169dbb278/run.yaml
# Build and run environment
pip install -e . \
&& llama stack build --config ./build.yaml --image-type conda \
&& llama stack run ./run.yaml \
--port 5001
```
</details>
<details>
<summary>Manual tests</summary>
Using this Jupyter notebook to test manually:
2140976d76/hello.ipynb
Use this code to test passing in the API key via `provider_data`:
```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5001",
)

response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[
        {"role": "user", "content": "Hello, world client!"},
    ],
    # Test passing in groq_api_key from the client
    # Need to comment out the groq_api_key in the run.yaml file
    x_llama_stack_provider_data='{"groq_api_key": "<api-key>"}',
    # stream=True,
)
response
```
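For the streaming path (the commented-out `stream=True` above), a minimal consumption sketch. This is hypothetical and not part of the PR; it assumes the client yields chunks shaped like the `ChatCompletionResponseStreamChunk` events this adapter produces, with the text in `chunk.event.delta`:

```python
# Hypothetical sketch, not part of the PR: stream the response and print each
# text delta as it arrives (assumes chunks expose `chunk.event.delta`).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Hello, world client!"}],
    stream=True,
)
for chunk in response:
    print(chunk.event.delta, end="", flush=True)
```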
</details>
<details>
<summary>Integration</summary>
`pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k groq`
(run in the same environment)
```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_3b-groq] PASSED [ 6%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_3b-groq] SKIPPED (Other inf...) [ 12%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_3b-groq] SKIPPED [ 18%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_3b-groq] PASSED [ 25%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-groq] SKIPPED (Ot...) [ 31%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_3b-groq] PASSED [ 37%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_3b-groq] SKIPPED [ 43%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_3b-groq] SKIPPED [ 50%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-groq] PASSED [ 56%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-groq] SKIPPED (Other inf...) [ 62%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_8b-groq] SKIPPED [ 68%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-groq] PASSED [ 75%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-groq] SKIPPED (Ot...) [ 81%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-groq] PASSED [ 87%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-groq] SKIPPED [ 93%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-groq] SKIPPED [100%]
======================================= 6 passed, 10 skipped, 160 deselected, 7 warnings in 2.05s ========================================
```
</details>
<details>
<summary>Unit tests</summary>
`pytest llama_stack/providers/tests/inference/groq/ -v`
```
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_sets_model PASSED [ 5%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_user_message PASSED [ 10%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_system_message PASSED [ 15%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_completion_message PASSED [ 20%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_logprobs PASSED [ 25%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_response_format PASSED [ 30%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_repetition_penalty PASSED [ 35%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_stream PASSED [ 40%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_n_is_1 PASSED [ 45%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_if_max_tokens_is_0_then_it_is_not_included PASSED [ 50%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_max_tokens_if_set PASSED [ 55%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_temperature PASSED [ 60%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_top_p PASSED [ 65%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_returns_response PASSED [ 70%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_stop_to_end_of_message PASSED [ 75%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_length_to_end_of_message PASSED [ 80%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertStreamChatCompletionResponse::test_returns_stream PASSED [ 85%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_raises_runtime_error_if_config_is_not_groq_config PASSED [ 90%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_returns_groq_adapter PASSED [ 95%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqConfig::test_api_key_defaults_to_env_var PASSED [100%]
==================================================== 20 passed, 11 warnings in 0.08s =====================================================
```
</details>
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [x] Updated relevant documentation
- [x] Wrote necessary unit or integration tests.
Parent: `e3f187fb83` · Commit: `e1f42eb5a5`
10 changed files with 692 additions and 0 deletions
Provider support table in the README (adds a Groq row):

```diff
@@ -84,6 +84,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
 |  Fireworks  |  Hosted  | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |  |  |
 |  AWS Bedrock  |  Hosted  |  | :heavy_check_mark: |  | :heavy_check_mark: |  |
 |  Together  |  Hosted  | :heavy_check_mark: | :heavy_check_mark: |  | :heavy_check_mark: |  |
+|  Groq  |  Hosted  |  | :heavy_check_mark: |  |  |  |
 |  Ollama  |  Single Node  |  | :heavy_check_mark: |  |  |  |
 |  TGI  |  Hosted and Single Node  |  | :heavy_check_mark: |  |  |  |
 |  [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node |  | :heavy_check_mark: |  |  |  |
```
Inference provider registry (registers the `remote::groq` adapter):

```diff
@@ -154,6 +154,16 @@ def available_providers() -> List[ProviderSpec]:
             provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="groq",
+            pip_packages=["groq"],
+            module="llama_stack.providers.remote.inference.groq",
+            config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
+        ),
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
```
**`llama_stack/providers/remote/inference/groq/__init__.py`** (new file, 26 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack.apis.inference import Inference

from .config import GroqConfig


class GroqProviderDataValidator(BaseModel):
    groq_api_key: str


async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
    # import dynamically so the import is used only when it is needed
    from .groq import GroqInferenceAdapter

    if not isinstance(config, GroqConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")

    adapter = GroqInferenceAdapter(config)
    return adapter
```
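A hedged usage sketch of the `get_adapter_impl` factory above (hypothetical, not part of the PR; the `"dummy-key"` value is illustrative only):

```python
# Hypothetical usage sketch, not part of the PR.
import asyncio

from llama_stack.providers.remote.inference.groq import get_adapter_impl
from llama_stack.providers.remote.inference.groq.config import GroqConfig


async def main() -> None:
    # The factory validates the config type and returns a GroqInferenceAdapter.
    adapter = await get_adapter_impl(GroqConfig(api_key="dummy-key"), None)
    print(type(adapter).__name__)  # GroqInferenceAdapter


asyncio.run(main())
```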
**`llama_stack/providers/remote/inference/groq/config.py`** (new file, 19 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class GroqConfig(BaseModel):
    api_key: Optional[str] = Field(
        # The Groq client library loads the GROQ_API_KEY environment variable by default
        default=None,
        description="The Groq API key",
    )
```
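A small sketch (hypothetical, not part of the PR) of the two ways the config above can be populated:

```python
# Hypothetical sketch, not part of the PR.
import os

from llama_stack.providers.remote.inference.groq.config import GroqConfig

# Explicit key in the stack configuration:
explicit = GroqConfig(api_key=os.environ.get("GROQ_API_KEY"))

# Or leave it unset; the Groq client library falls back to the
# GROQ_API_KEY environment variable (or the per-request provider data).
implicit = GroqConfig()
assert implicit.api_key is None
```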
**`llama_stack/providers/remote/inference/groq/groq.py`** (new file, 150 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    build_model_alias_with_just_provider_model_id,
    ModelRegistryHelper,
)

from .groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)

_MODEL_ALIASES = [
    build_model_alias(
        "llama3-8b-8192",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "llama-3.3-70b-versatile",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # Groq only contains a preview version for llama-3.2-3b
    # Preview models aren't recommended for production use, but we include this one
    # to pass the test fixture
    # TODO(aidand): Replace this with a stable model once Groq supports it
    build_model_alias(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
]


class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # Groq doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[
            ToolPromptFormat
        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        model_id = self.get_provider_model_id(model_id)
        if model_id == "llama-3.2-3b-preview":
            warnings.warn(
                "Groq only contains a preview version for llama-3.2-3b-instruct. "
                "Preview models aren't recommended for production use. "
                "They can be discontinued on short notice."
            )

        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                tool_choice=tool_choice,
                tool_prompt_format=tool_prompt_format,
                stream=stream,
                logprobs=logprobs,
            )
        )

        response = self._get_client().chat.completions.create(**request)

        if stream:
            return convert_chat_completion_response_stream(response)
        else:
            return convert_chat_completion_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    def _get_client(self) -> Groq:
        if self._config.api_key is not None:
            return Groq(api_key=self._config.api_key)
        else:
            provider_data = self.get_request_provider_data()

            if provider_data is None or not provider_data.groq_api_key:
                raise ValueError(
                    'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "<your api key>" }'
                )
            return Groq(api_key=provider_data.groq_api_key)
```
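To make the key-resolution order in `_get_client` explicit, here is a standalone sketch (a hypothetical helper, not part of the PR) that mirrors its logic: a key in the static config wins, otherwise the per-request `X-LlamaStack-ProviderData` header must supply `groq_api_key`.

```python
# Hypothetical helper mirroring GroqInferenceAdapter._get_client, not part of the PR.
from typing import Optional


def resolve_groq_api_key(config_key: Optional[str], provider_data_key: Optional[str]) -> str:
    if config_key is not None:
        # Key set in run.yaml (GroqConfig.api_key) takes precedence.
        return config_key
    if not provider_data_key:
        raise ValueError(
            'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "<your api key>" }'
        )
    # Otherwise fall back to the key sent by the client per request.
    return provider_data_key


assert resolve_groq_api_key("key-from-run-yaml", None) == "key-from-run-yaml"
assert resolve_groq_api_key(None, "key-from-header") == "key-from-header"
```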
**`llama_stack/providers/remote/inference/groq/groq_utils.py`** (new file, 153 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncGenerator, Literal

from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
    ChatCompletionAssistantMessageParam,
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_system_message_param import (
    ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_user_message_param import (
    ChatCompletionUserMessageParam,
)

from groq.types.chat.completion_create_params import CompletionCreateParams

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    Message,
    StopReason,
)


def convert_chat_completion_request(
    request: ChatCompletionRequest,
) -> CompletionCreateParams:
    """
    Convert a ChatCompletionRequest to a Groq API-compatible dictionary.

    Warns client if request contains unsupported features.
    """

    if request.logprobs:
        # Groq doesn't support logprobs at the time of writing
        warnings.warn("logprobs are not supported yet")

    if request.response_format:
        # Groq's JSON mode is beta at the time of writing
        warnings.warn("response_format is not supported yet")

    if request.sampling_params.repetition_penalty != 1.0:
        # groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
        # seem to have different semantics
        # frequency_penalty defaults to 0 and is a float between -2.0 and 2.0
        # repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
        # so we exclude it for now
        warnings.warn("repetition_penalty is not supported")

    if request.tools:
        warnings.warn("tools are not supported yet")

    return CompletionCreateParams(
        model=request.model,
        messages=[_convert_message(message) for message in request.messages],
        logprobs=None,
        frequency_penalty=None,
        stream=request.stream,
        max_tokens=request.sampling_params.max_tokens or None,
        temperature=request.sampling_params.temperature,
        top_p=request.sampling_params.top_p,
    )


def _convert_message(message: Message) -> ChatCompletionMessageParam:
    if message.role == "system":
        return ChatCompletionSystemMessageParam(role="system", content=message.content)
    elif message.role == "user":
        return ChatCompletionUserMessageParam(role="user", content=message.content)
    elif message.role == "assistant":
        return ChatCompletionAssistantMessageParam(
            role="assistant", content=message.content
        )
    else:
        raise ValueError(f"Invalid message role: {message.role}")


def convert_chat_completion_response(
    response: ChatCompletion,
) -> ChatCompletionResponse:
    # groq only supports n=1 at time of writing, so there is only one choice
    choice = response.choices[0]
    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            content=choice.message.content,
            stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
        ),
    )


def _map_finish_reason_to_stop_reason(
    finish_reason: Literal["stop", "length", "tool_calls"]
) -> StopReason:
    """
    Convert a Groq chat completion finish_reason to a StopReason.

    finish_reason: Literal["stop", "length", "tool_calls"]
        - stop -> model hit a natural stop point or a provided stop sequence
        - length -> maximum number of tokens specified in the request was reached
        - tool_calls -> model called a tool
    """
    if finish_reason == "stop":
        return StopReason.end_of_turn
    elif finish_reason == "length":
        return StopReason.out_of_tokens
    elif finish_reason == "tool_calls":
        raise NotImplementedError("tool_calls is not supported yet")
    else:
        raise ValueError(f"Invalid finish reason: {finish_reason}")


async def convert_chat_completion_response_stream(
    stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:

    event_type = ChatCompletionResponseEventType.start
    for chunk in stream:
        choice = chunk.choices[0]

        # We assume there's only one finish_reason for the entire stream.
        # We collect the last finish_reason
        if choice.finish_reason:
            stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)

        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=event_type,
                delta=choice.delta.content or "",
                logprobs=None,
            )
        )
        event_type = ChatCompletionResponseEventType.progress

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            logprobs=None,
            stop_reason=stop_reason,
        )
    )
```
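For readers following the streaming conversion above, a self-contained sketch (hypothetical, not part of the PR) of the event-type sequencing that `convert_chat_completion_response_stream` implements: the first chunk is tagged `start`, later chunks `progress`, and a final empty `complete` chunk closes the stream.

```python
# Hypothetical sketch of the start/progress/complete sequencing, not part of the PR.
import asyncio
from typing import AsyncGenerator, List, Tuple


async def sequence_events(deltas: List[str]) -> AsyncGenerator[Tuple[str, str], None]:
    event_type = "start"
    for delta in deltas:
        yield (event_type, delta)
        event_type = "progress"  # every chunk after the first is a progress event
    yield ("complete", "")  # final empty chunk carries the stop reason in the real code


async def main() -> None:
    async for event in sequence_events(["Hello ", "World ", " !"]):
        print(event)
    # ('start', 'Hello ')
    # ('progress', 'World ')
    # ('progress', ' !')
    # ('complete', '')


asyncio.run(main())
```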
Inference test fixtures (adds an `inference_groq` fixture and registers it in `INFERENCE_FIXTURES`):

```diff
@@ -19,6 +19,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.groq import GroqConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
@@ -151,6 +152,22 @@ def inference_together() -> ProviderFixture:
     )


+@pytest.fixture(scope="session")
+def inference_groq() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="groq",
+                provider_type="remote::groq",
+                config=GroqConfig().model_dump(),
+            )
+        ],
+        provider_data=dict(
+            groq_api_key=get_env_or_fail("GROQ_API_KEY"),
+        ),
+    )
+
+
 @pytest.fixture(scope="session")
 def inference_bedrock() -> ProviderFixture:
     return ProviderFixture(
@@ -236,6 +253,7 @@ INFERENCE_FIXTURES = [
     "ollama",
     "fireworks",
     "together",
+    "groq",
     "vllm_remote",
     "remote",
     "bedrock",
```
**`llama_stack/providers/tests/inference/groq/test_groq_utils.py`** (new file, 271 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
    ChatCompletionChunk,
    Choice as StreamChoice,
    ChoiceDelta,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponseEventType,
    CompletionMessage,
    StopReason,
    SystemMessage,
    UserMessage,
)
from llama_stack.providers.remote.inference.groq.groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)


class TestConvertChatCompletionRequest:
    def test_sets_model(self):
        request = self._dummy_chat_completion_request()
        request.model = "Llama-3.2-3B"

        converted = convert_chat_completion_request(request)

        assert converted["model"] == "Llama-3.2-3B"

    def test_converts_user_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [UserMessage(content="Hello World")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
        ]

    def test_converts_system_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [SystemMessage(content="You are a helpful assistant.")]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "system", "content": "You are a helpful assistant."},
        ]

    def test_converts_completion_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [
            UserMessage(content="Hello World"),
            CompletionMessage(
                content="Hello World! How can I help you today?",
                stop_reason=StopReason.end_of_message,
            ),
        ]

        converted = convert_chat_completion_request(request)

        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
            {"role": "assistant", "content": "Hello World! How can I help you today?"},
        ]

    def test_does_not_include_logprobs(self):
        request = self._dummy_chat_completion_request()
        request.logprobs = True

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "logprobs are not supported yet" in warnings[0].message.args[0]
        assert converted.get("logprobs") is None

    def test_does_not_include_response_format(self):
        request = self._dummy_chat_completion_request()
        request.response_format = {
            "type": "json_object",
            "json_schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "number"},
                },
            },
        }

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "response_format is not supported yet" in warnings[0].message.args[0]
        assert converted.get("response_format") is None

    def test_does_not_include_repetition_penalty(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.repetition_penalty = 1.5

        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)

        assert "repetition_penalty is not supported" in warnings[0].message.args[0]
        assert converted.get("repetition_penalty") is None
        assert converted.get("frequency_penalty") is None

    def test_includes_stream(self):
        request = self._dummy_chat_completion_request()
        request.stream = True

        converted = convert_chat_completion_request(request)

        assert converted["stream"] is True

    def test_if_max_tokens_is_0_then_it_is_not_included(self):
        request = self._dummy_chat_completion_request()
        # 0 is the default value for max_tokens
        # So we assume that if it's 0, the user didn't set it
        request.sampling_params.max_tokens = 0

        converted = convert_chat_completion_request(request)

        assert converted.get("max_tokens") is None

    def test_includes_max_tokens_if_set(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.max_tokens = 100

        converted = convert_chat_completion_request(request)

        assert converted["max_tokens"] == 100

    def _dummy_chat_completion_request(self):
        return ChatCompletionRequest(
            model="Llama-3.2-3B",
            messages=[UserMessage(content="Hello World")],
        )

    def test_includes_temperature(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.temperature = 0.5

        converted = convert_chat_completion_request(request)

        assert converted["temperature"] == 0.5

    def test_includes_top_p(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.top_p = 0.95

        converted = convert_chat_completion_request(request)

        assert converted["top_p"] == 0.95


class TestConvertNonStreamChatCompletionResponse:
    def test_returns_response(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].message.content = "Hello World"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.content == "Hello World"

    def test_maps_stop_to_end_of_message(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "stop"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_turn

    def test_maps_length_to_end_of_message(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "length"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.out_of_tokens

    def _dummy_chat_completion_response(self):
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant", content="Hello World"
                    ),
                    finish_reason="stop",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )


class TestConvertStreamChatCompletionResponse:
    @pytest.mark.asyncio
    async def test_returns_stream(self):
        def chat_completion_stream():
            messages = ["Hello ", "World ", " !"]
            for i, message in enumerate(messages):
                chunk = self._dummy_chat_completion_chunk()
                chunk.choices[0].delta.content = message
                if i == len(messages) - 1:
                    chunk.choices[0].finish_reason = "stop"
                else:
                    chunk.choices[0].finish_reason = None
                yield chunk

            chunk = self._dummy_chat_completion_chunk()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = chat_completion_stream()
        converted = convert_chat_completion_response_stream(stream)

        iter = converted.__aiter__()
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta == "Hello "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == "World "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == " !"

        # Dummy chunk to ensure the last chunk is really the end of the stream
        # This one technically maps to Groq's final "stop" chunk
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == ""

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
        assert chunk.event.delta == ""
        assert chunk.event.stop_reason == StopReason.end_of_turn

        with pytest.raises(StopAsyncIteration):
            await iter.__anext__()

    def _dummy_chat_completion_chunk(self):
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(role="assistant", content="Hello World"),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )
```
**`llama_stack/providers/tests/inference/groq/test_init.py`** (new file, 29 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack.apis.inference import Inference
from llama_stack.providers.remote.inference.groq import get_adapter_impl
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter

from llama_stack.providers.remote.inference.ollama import OllamaImplConfig


class TestGroqInit:
    @pytest.mark.asyncio
    async def test_raises_runtime_error_if_config_is_not_groq_config(self):
        config = OllamaImplConfig(model="llama3.1-8b-8192")

        with pytest.raises(RuntimeError):
            await get_adapter_impl(config, None)

    @pytest.mark.asyncio
    async def test_returns_groq_adapter(self):
        config = GroqConfig()
        adapter = await get_adapter_impl(config, None)
        assert type(adapter) is GroqInferenceAdapter
        assert isinstance(adapter, Inference)
```
Integration tests in `TestInference` (tool-calling tests are skipped for Groq, which doesn't support tool calling yet):

```diff
@@ -371,6 +371,14 @@ class TestInference:
         sample_messages,
         sample_tool_definition,
     ):
+        inference_impl, _ = inference_stack
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type in ("remote::groq",):
+            pytest.skip(
+                provider.__provider_spec__.provider_type
+                + " doesn't support tool calling yet"
+            )
+
         inference_impl, _ = inference_stack
         messages = sample_messages + [
             UserMessage(
@@ -411,6 +419,13 @@ class TestInference:
         sample_tool_definition,
     ):
         inference_impl, _ = inference_stack
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type in ("remote::groq",):
+            pytest.skip(
+                provider.__provider_spec__.provider_type
+                + " doesn't support tool calling yet"
+            )
+
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
```