# What does this PR do?
Contributes towards issue #432 by adding:
- Groq text chat completions
- Streaming
- All the sampling params that Groq supports
A lot of inspiration was taken from @mattf's good work in
https://github.com/meta-llama/llama-stack/pull/355
**What this PR does not do**
- Tool calls (future PR)
- Adding the llama-guard model
- Investigating whether we can add embeddings
### PR Train
- https://github.com/meta-llama/llama-stack/pull/609 👈
- https://github.com/meta-llama/llama-stack/pull/630
## Test Plan
<details>
<summary>Environment</summary>
```bash
export GROQ_API_KEY=<api_key>
wget https://raw.githubusercontent.com/aidando73/llama-stack/240e6e2a9c20450ffdcfbabd800a6c0291f19288/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/92c9b5297f9eda6a6e901e1adbd894e169dbb278/run.yaml
# Build and run environment
pip install -e . \
&& llama stack build --config ./build.yaml --image-type conda \
&& llama stack run ./run.yaml \
--port 5001
```
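
Once the stack is running, a quick smoke test (not part of this PR) can confirm the Groq-backed models are registered. This assumes the `llama-stack-client` Python package is installed, that its `models.list()` call is available, and that the returned model objects expose an `identifier` field:

```python
from llama_stack_client import LlamaStackClient

# Assumes the server started above is listening on port 5001.
client = LlamaStackClient(base_url="http://localhost:5001")

# The Groq-served Llama models should show up among the registered models.
for model in client.models.list():
    print(model.identifier)
```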
</details>
<details>
<summary>Manual tests</summary>
Using this Jupyter notebook to test manually:
2140976d76/hello.ipynb
Use this code to test passing in the API key via provider_data:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(
base_url="http://localhost:5001",
)
response = client.inference.chat_completion(
model_id="Llama3.2-3B-Instruct",
messages=[
{"role": "user", "content": "Hello, world client!"},
],
# Test passing in groq_api_key from the client
# Need to comment out the groq_api_key in the run.yaml file
x_llama_stack_provider_data='{"groq_api_key": "<api-key>"}',
# stream=True,
)
response
```
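
To exercise the streaming path and the sampling params together, a variation like the one below also works. It is an illustrative sketch rather than the exact notebook contents; the `sampling_params` field names assume the flat `SamplingParams` schema in use at the time of this PR, and the values are arbitrary:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

# Streaming variant of the call above, with a few sampling params set.
response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[
        {"role": "user", "content": "Write a haiku about Groq."},
    ],
    # Assumed field names: temperature, top_p and max_tokens on SamplingParams.
    sampling_params={"temperature": 0.7, "top_p": 0.9, "max_tokens": 100},
    stream=True,
)

# Each item is a ChatCompletionResponseStreamChunk; print the raw chunks.
for chunk in response:
    print(chunk)
```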
</details>
<details>
<summary>Integration</summary>
`pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k groq`
(run in the same environment)
```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_3b-groq] PASSED [ 6%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_3b-groq] SKIPPED (Other inf...) [ 12%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_3b-groq] SKIPPED [ 18%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_3b-groq] PASSED [ 25%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-groq] SKIPPED (Ot...) [ 31%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_3b-groq] PASSED [ 37%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_3b-groq] SKIPPED [ 43%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_3b-groq] SKIPPED [ 50%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-groq] PASSED [ 56%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-groq] SKIPPED (Other inf...) [ 62%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_8b-groq] SKIPPED [ 68%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-groq] PASSED [ 75%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-groq] SKIPPED (Ot...) [ 81%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-groq] PASSED [ 87%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-groq] SKIPPED [ 93%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-groq] SKIPPED [100%]
======================================= 6 passed, 10 skipped, 160 deselected, 7 warnings in 2.05s ========================================
```
</details>
<details>
<summary>Unit tests</summary>
`pytest llama_stack/providers/tests/inference/groq/ -v`
```
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_sets_model PASSED [ 5%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_user_message PASSED [ 10%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_system_message PASSED [ 15%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_completion_message PASSED [ 20%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_logprobs PASSED [ 25%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_response_format PASSED [ 30%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_repetition_penalty PASSED [ 35%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_stream PASSED [ 40%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_n_is_1 PASSED [ 45%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_if_max_tokens_is_0_then_it_is_not_included PASSED [ 50%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_max_tokens_if_set PASSED [ 55%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_temperature PASSED [ 60%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_top_p PASSED [ 65%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_returns_response PASSED [ 70%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_stop_to_end_of_message PASSED [ 75%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_length_to_end_of_message PASSED [ 80%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertStreamChatCompletionResponse::test_returns_stream PASSED [ 85%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_raises_runtime_error_if_config_is_not_groq_config PASSED [ 90%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_returns_groq_adapter PASSED [ 95%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqConfig::test_api_key_defaults_to_env_var PASSED [100%]
==================================================== 20 passed, 11 warnings in 0.08s =====================================================
```
</details>
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [x] Updated relevant documentation
- [x] Wrote necessary unit or integration tests.
The new Groq inference adapter (150 lines, 5.2 KiB, Python):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

from groq import Groq

from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    build_model_alias_with_just_provider_model_id,
    ModelRegistryHelper,
)

from .groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)

_MODEL_ALIASES = [
    build_model_alias(
        "llama3-8b-8192",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "llama-3.3-70b-versatile",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # Groq only contains a preview version for llama-3.2-3b
    # Preview models aren't recommended for production use, but we include this one
    # to pass the test fixture
    # TODO(aidand): Replace this with a stable model once Groq supports it
    build_model_alias(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
]


class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # Groq doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[
            ToolPromptFormat
        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        model_id = self.get_provider_model_id(model_id)
        if model_id == "llama-3.2-3b-preview":
            warnings.warn(
                "Groq only contains a preview version for llama-3.2-3b-instruct. "
                "Preview models aren't recommended for production use. "
                "They can be discontinued on short notice."
            )

        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                tool_choice=tool_choice,
                tool_prompt_format=tool_prompt_format,
                stream=stream,
                logprobs=logprobs,
            )
        )

        response = self._get_client().chat.completions.create(**request)

        if stream:
            return convert_chat_completion_response_stream(response)
        else:
            return convert_chat_completion_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    def _get_client(self) -> Groq:
        if self._config.api_key is not None:
            return Groq(api_key=self._config.api_key)
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.groq_api_key:
                raise ValueError(
                    'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "<your api key>" }'
                )
            return Groq(api_key=provider_data.groq_api_key)
```