Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 11:02:36 +00:00)

Adds groq inference adapter

Parent: 4e6c984c26
Commit: d8d0f4600d
10 changed files with 810 additions and 0 deletions

@@ -161,4 +161,16 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="groq",
+            pip_packages=[
+                "openai",
+            ],
+            module="llama_stack.providers.remote.inference.groq",
+            config_class="llama_stack.providers.remote.inference.groq.GroqImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
+        ),
+    ),
 ]

llama_stack/providers/remote/inference/groq/__init__.py (new file, 24 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import GroqImplConfig


class GroqProviderDataValidator(BaseModel):
    groq_api_key: str


async def get_adapter_impl(config: GroqImplConfig, _deps):
    from .groq import GroqInferenceAdapter

    assert isinstance(
        config, GroqImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = GroqInferenceAdapter(config)
    await impl.initialize()
    return impl

llama_stack/providers/remote/inference/groq/config.py (new file, 29 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class GroqImplConfig(BaseModel):
    url: str = Field(
        default="https://api.groq.com/openai/v1",
        description="The URL for the Groq API server",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The Groq API Key",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "url": "https://api.groq.com/openai/v1",
            "api_key": "${env.GROQ_API_KEY}",
        }
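
A quick illustration of how this config is meant to be used (not part of the diff): GroqImplConfig can be constructed directly with an explicit key, while sample_run_config emits the ${env.GROQ_API_KEY} placeholder that the stack resolves from the environment when the run config is materialized. A minimal sketch, assuming the package is installed and importable:

# Illustrative sketch only; not part of this commit.
import os

from llama_stack.providers.remote.inference.groq.config import GroqImplConfig

# Explicit key, e.g. for programmatic use or tests.
cfg = GroqImplConfig(api_key=os.environ.get("GROQ_API_KEY"))
print(cfg.url)  # defaults to https://api.groq.com/openai/v1

# Template form: the ${env.GROQ_API_KEY} placeholder is substituted by the
# stack when rendering run.yaml (see the generated run.yaml below).
print(GroqImplConfig.sample_run_config())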

llama_stack/providers/remote/inference/groq/groq.py (new file, 464 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import AsyncGenerator, Dict, List, Optional, Union

from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import AsyncOpenAI

from llama_stack.apis.inference import (
    AsyncIterator,
    ChatCompletionRequest,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedTextMedia,
    LogProbConfig,
    ResponseFormat,
    SamplingParams,
    ToolChoice,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
    ChatCompletionResponseStreamChunk,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)

from .config import GroqImplConfig


class GroqErrorCode(str, Enum):
    INVALID_AUTH = "invalid_authentication"
    RATE_LIMIT_EXCEEDED = "rate_limit_exceeded"
    QUOTA_EXCEEDED = "quota_exceeded"
    CONTEXT_LENGTH_EXCEEDED = "context_length_exceeded"
    INVALID_REQUEST = "invalid_request"
    MODEL_NOT_FOUND = "model_not_found"


MODEL_ALIASES = [
    build_model_alias(
        "llama-3.2-1b-preview",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias(
        "llama-3.2-11b-vision-preview",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias(
        "llama-3.2-90b-vision-preview",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_alias(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama-3.1-70b-versatile",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "llama3-8b-8192",
        CoreModelId.llama3_8b_instruct.value,
    ),
    build_model_alias(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "llama3-groq-8b-8192-tool-use-preview",
        CoreModelId.llama3_8b_instruct.value,
    ),
    build_model_alias(
        "llama3-groq-70b-8192-tool-use-preview",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "llama-guard-3-8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
]

UNSUPPORTED_PARAMS = {
    "logprobs",
    "top_logprobs",
    "response_format",
}


class GroqInferenceAdapter(
    ModelRegistryHelper, Inference, NeedsRequestProviderData
):
    """Groq inference adapter using OpenAI client compatibility layer.

    This adapter provides access to Groq's AI models through their OpenAI-compatible API.
    It handles authentication, request formatting, and response processing while managing
    unsupported features gracefully.

    Supports tool/function calling for compatible models.
    """

    def __init__(self, config: GroqImplConfig) -> None:
        """Initialize the Groq inference adapter.

        Args:
            config: Configuration for the Groq implementation
        """
        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())
        self._client: Optional[AsyncOpenAI] = None

    @property
    def client(self) -> AsyncOpenAI:
        """Get or create the OpenAI client instance.

        Returns:
            AsyncOpenAI: The configured client instance
        """
        if self._client is None:
            self._client = AsyncOpenAI(
                base_url=self.config.url,
                api_key=self._get_api_key(),
                timeout=60.0,
            )
        return self._client

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    def _get_api_key(self) -> str:
        """Get the API key from config or request headers.

        Returns:
            str: The API key to use

        Raises:
            ValueError: If no API key is available
        """
        if self.config.api_key is not None:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.groq_api_key:
            raise ValueError(
                'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": <your api key>}'
            )
        return provider_data.groq_api_key

    def _filter_unsupported_params(self, params: Dict) -> Dict:
        """Remove parameters not supported by Groq API.

        Args:
            params: Original parameters dictionary

        Returns:
            Dict: Filtered parameters dictionary
        """
        return {k: v for k, v in params.items() if k not in UNSUPPORTED_PARAMS}

    def _convert_tool_to_function(self, tool: ToolDefinition) -> dict:
        """Convert a ToolDefinition to Groq function format.

        Args:
            tool: Tool definition to convert

        Returns:
            dict: Function definition in Groq format
        """
        return {
            "type": "function",
            "function": {
                "name": tool.tool_name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        name: {
                            "type": param.param_type,
                            "description": param.description,
                        }
                        for name, param in tool.parameters.items()
                    },
                    "required": list(tool.parameters.keys()),
                },
            },
        }

    async def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
        """Prepare parameters for the API request.

        Args:
            request: The completion request

        Returns:
            dict: Prepared parameters for the API call
        """
        sampling_options = get_sampling_options(request.sampling_params)
        filtered_options = self._filter_unsupported_params(sampling_options)

        if "temperature" in filtered_options:
            filtered_options["temperature"] = min(
                max(filtered_options["temperature"], 0), 2
            )

        input_dict = {}
        if isinstance(request, ChatCompletionRequest):
            input_dict["messages"] = [
                {"role": message.role, "content": message.content}
                for message in request.messages
            ]

            if request.tools:
                input_dict["tools"] = [
                    self._convert_tool_to_function(tool)
                    for tool in request.tools
                ]

            if request.tool_choice == ToolChoice.auto:
                input_dict["tool_choice"] = "auto"
            elif request.tool_choice == ToolChoice.required:
                input_dict["tool_choice"] = "required"
            elif isinstance(request.tool_choice, str):
                input_dict["tool_choice"] = {
                    "type": "function",
                    "function": {"name": request.tool_choice},
                }
            else:
                input_dict["tool_choice"] = "none"

        else:
            input_dict["prompt"] = request.content

        return {
            "model": request.model,
            **input_dict,
            **filtered_options,
            "stream": request.stream,
        }

    async def _handle_groq_error(self, e: Exception) -> None:
        """Handle Groq specific API errors with detailed messages.

        Args:
            e: The exception to handle

        Raises:
            ValueError: For client errors
            RuntimeError: For server errors
        """
        error_msg = str(e)
        error_data = {}

        try:
            if hasattr(e, "response"):
                error_data = e.response.json().get("error", {})
        except Exception:
            pass

        error_code = error_data.get("code", "")
        error_message = error_data.get("message", error_msg)

        if "401" in error_msg or error_code == GroqErrorCode.INVALID_AUTH:
            raise ValueError("Invalid API key or unauthorized access") from e

        elif (
            "429" in error_msg
            or error_code == GroqErrorCode.RATE_LIMIT_EXCEEDED
        ):
            raise ValueError(
                "Rate limit exceeded. Please try again later"
            ) from e

        elif error_code == GroqErrorCode.QUOTA_EXCEEDED:
            raise ValueError(
                "API quota exceeded. Please check your usage limits"
            ) from e

        elif error_code == GroqErrorCode.CONTEXT_LENGTH_EXCEEDED:
            raise ValueError(
                "Total number of input and output tokens exceeds model's context length"
            ) from e

        elif error_code == GroqErrorCode.INVALID_REQUEST:
            raise ValueError(f"Invalid request: {error_message}") from e

        elif error_code == GroqErrorCode.MODEL_NOT_FOUND:
            raise ValueError(f"Model not found: {error_message}") from e

        raise RuntimeError(f"Groq API error: {error_message}") from e

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        try:
            params = await self._get_params(request)
            response = await self.client.chat.completions.create(**params)

            if response.choices[0].message.tool_calls:
                tool_call = response.choices[0].message.tool_calls[0]
                choice = OpenAICompatCompletionChoice(
                    finish_reason=response.choices[0].finish_reason,
                    text="",
                    tool_calls=[
                        {
                            "tool_name": tool_call.function.name,
                            "arguments": tool_call.function.arguments or "",
                        }
                    ],
                )
            else:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=response.choices[0].finish_reason,
                    text=response.choices[0].message.content or "",
                    tool_calls=[],
                )

            compat_response = OpenAICompatCompletionResponse(choices=[choice])
            return process_chat_completion_response(
                compat_response, self.formatter
            )

        except Exception as e:
            await self._handle_groq_error(e)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
        try:
            params = await self._get_params(request)
            stream = await self.client.chat.completions.create(**params)

            async def _to_async_generator():
                async for chunk in stream:
                    if chunk.choices[0].delta.tool_calls:
                        tool_call = chunk.choices[0].delta.tool_calls[0]
                        choice = OpenAICompatCompletionChoice(
                            finish_reason=chunk.choices[0].finish_reason,
                            text="",
                            tool_calls=[
                                {
                                    "tool_name": tool_call.function.name
                                    if tool_call.function
                                    else None,
                                    "arguments": tool_call.function.arguments
                                    if tool_call.function
                                    else "",
                                }
                            ]
                            if tool_call.function
                            else None,
                        )
                    else:
                        choice = OpenAICompatCompletionChoice(
                            finish_reason=chunk.choices[0].finish_reason,
                            text=chunk.choices[0].delta.content or "",
                            tool_calls=[],
                        )
                    yield OpenAICompatCompletionResponse(choices=[choice])

            async for chunk in process_chat_completion_stream_response(
                _to_async_generator(), self.formatter
            ):
                yield chunk

        except Exception as e:
            await self._handle_groq_error(e)

    def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]
    ]:
        raise NotImplementedError("Groq does not support text completion")

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Handle chat completion requests.

        Args:
            model_id: The model identifier
            messages: List of chat messages
            sampling_params: Parameters for text generation
            tools: Tool definitions for function calling
            tool_choice: Tool choice option
            tool_prompt_format: Tool prompt format
            response_format: Response format (not supported)
            stream: Whether to stream the response
            logprobs: Log probability config (not supported)

        Returns:
            AsyncGenerator: The completion response
        """
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Embeddings are not supported.

        Raises:
            NotImplementedError: Always raised as this feature is not supported
        """
        raise NotImplementedError("Embeddings not supported for Groq")
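
Since GroqImplConfig.api_key defaults to None and the adapter mixes in NeedsRequestProviderData, _get_api_key falls back to per-request provider data, as the error message above spells out. A minimal sketch of supplying that header from a client (not part of the diff); the endpoint path, port, and request body are placeholders for illustration only, not defined by this commit:

# Illustrative sketch only; not part of this commit.
import json
import os

import requests  # any HTTP client that can set headers works

# The adapter reads the key from the X-LlamaStack-ProviderData header when no
# api_key is configured (see GroqInferenceAdapter._get_api_key above).
headers = {
    "Content-Type": "application/json",
    "X-LlamaStack-ProviderData": json.dumps(
        {"groq_api_key": os.environ["GROQ_API_KEY"]}
    ),
}

# Placeholder route and payload purely for illustration; use the routes and
# schema exposed by the running Llama Stack server.
resp = requests.post(
    "http://localhost:5001/inference/chat-completion",
    headers=headers,
    json={
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.status_code, resp.text[:200])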

@@ -22,6 +22,7 @@ from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.inference.groq import GroqImplConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture

@@ -156,6 +157,22 @@ def inference_nvidia() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_groq() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="groq",
+                provider_type="remote::groq",
+                config=GroqImplConfig().model_dump(),
+            )
+        ],
+        provider_data=dict(
+            groq_api_key=get_env_or_fail("GROQ_API_KEY"),
+        ),
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.

@@ -190,6 +207,7 @@ INFERENCE_FIXTURES = [
     "remote",
     "bedrock",
     "nvidia",
+    "groq",
 ]

llama_stack/templates/groq/__init__.py (new file, 7 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .groq import get_distribution_template  # noqa: F401

llama_stack/templates/groq/build.yaml (new file, 19 lines)

version: '2'
name: groq
distribution_spec:
  description: Use Groq for running LLM inference
  docker_image: null
  providers:
    inference:
    - remote::groq
    memory:
    - inline::faiss
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
image_type: conda

llama_stack/templates/groq/doc_template.md (new file, 66 lines)

---
orphan: true
---
# Groq Distribution

```{toctree}
:maxdepth: 2
:hidden:

self
```

The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.

{{ providers_table }}

{% if run_config_env_vars %}
### Environment Variables

The following environment variables can be configured:

{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}

{% if default_models %}
### Models

The following models are available by default:

{% for model in default_models %}
- `{{ model.model_id }}`
{% endfor %}
{% endif %}

### Prerequisite: API Keys

Make sure you have access to a Groq API Key. You can get one by signing up at [console.groq.com](https://console.groq.com).

## Running Llama Stack with Groq

You can do this via Conda (build code) or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-{{ name }} \
  --port $LLAMA_STACK_PORT \
  --env GROQ_API_KEY=$GROQ_API_KEY
```

### Via Conda

```bash
llama stack build --template {{ name }} --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env GROQ_API_KEY=$GROQ_API_KEY
```

llama_stack/templates/groq/groq.py (new file, 76 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_models.sku_list import all_registered_models

from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.groq import GroqImplConfig
from llama_stack.providers.remote.inference.groq.groq import MODEL_ALIASES
from llama_stack.templates.template import (
    DistributionTemplate,
    RunConfigSettings,
)


def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::groq"],
        "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
    }

    inference_provider = Provider(
        provider_id="groq",
        provider_type="remote::groq",
        config=GroqImplConfig.sample_run_config(),
    )

    core_model_to_hf_repo = {
        m.descriptor(): m.huggingface_repo for m in all_registered_models()
    }

    default_models = [
        ModelInput(
            model_id=core_model_to_hf_repo[m.llama_model],
            provider_model_id=m.provider_model_id,
        )
        for m in MODEL_ALIASES
    ]

    return DistributionTemplate(
        name="groq",
        distro_type="self_hosted",
        description="Use Groq for running LLM inference",
        docker_image=None,
        template_path=Path(__file__).parent / "doc_template.md",
        providers=providers,
        default_models=default_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                },
                default_models=default_models,
                default_shields=[
                    ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
                ],
            ),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "GROQ_API_KEY": (
                "",
                "Groq API Key for authentication",
            ),
        },
    )

llama_stack/templates/groq/run.yaml (new file, 95 lines)

version: '2'
image_name: groq
docker_image: null
conda_env: groq
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: groq
    provider_type: remote::groq
    config:
      url: https://api.groq.com/openai/v1
      api_key: ${env.GROQ_API_KEY}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: null
  provider_model_id: llama-3.2-1b-preview
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  provider_id: null
  provider_model_id: llama-3.2-3b-preview
- metadata: {}
  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
  provider_id: null
  provider_model_id: llama-3.2-11b-vision-preview
- metadata: {}
  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
  provider_id: null
  provider_model_id: llama-3.2-90b-vision-preview
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: null
  provider_model_id: llama-3.1-8b-instant
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: null
  provider_model_id: llama-3.1-70b-versatile
- metadata: {}
  model_id: meta-llama/Llama-3-8B-Instruct
  provider_id: null
  provider_model_id: llama3-8b-8192
- metadata: {}
  model_id: meta-llama/Llama-3-70B-Instruct
  provider_id: null
  provider_model_id: llama3-70b-8192
- metadata: {}
  model_id: meta-llama/Llama-3-8B-Instruct
  provider_id: null
  provider_model_id: llama3-groq-8b-8192-tool-use-preview
- metadata: {}
  model_id: meta-llama/Llama-3-70B-Instruct
  provider_id: null
  provider_model_id: llama3-groq-70b-8192-tool-use-preview
- metadata: {}
  model_id: meta-llama/Llama-Guard-3-8B
  provider_id: null
  provider_model_id: llama-guard-3-8b
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []