diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 124e6b0fa..296b32e18 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -8923,6 +8923,9 @@
"OpenAIChatCompletionToolCall": {
"type": "object",
"properties": {
+ "index": {
+ "type": "integer"
+ },
"id": {
"type": "string"
},
@@ -8937,9 +8940,7 @@
},
"additionalProperties": false,
"required": [
- "id",
- "type",
- "function"
+ "type"
],
"title": "OpenAIChatCompletionToolCall"
},
@@ -8954,10 +8955,6 @@
}
},
"additionalProperties": false,
- "required": [
- "name",
- "arguments"
- ],
"title": "OpenAIChatCompletionToolCallFunction"
},
"OpenAIDeveloperMessageParam": {
@@ -9563,7 +9560,7 @@
"choices": {
"type": "array",
"items": {
- "$ref": "#/components/schemas/OpenAIChoice"
+ "$ref": "#/components/schemas/OpenAIChunkChoice"
},
"description": "List of choices"
},
@@ -9605,10 +9602,12 @@
"description": "The reason the model stopped generating"
},
"index": {
- "type": "integer"
+ "type": "integer",
+ "description": "The index of the choice"
},
"logprobs": {
- "$ref": "#/components/schemas/OpenAIChoiceLogprobs"
+ "$ref": "#/components/schemas/OpenAIChoiceLogprobs",
+ "description": "(Optional) The log probabilities for the tokens in the message"
}
},
"additionalProperties": false,
@@ -9620,6 +9619,33 @@
"title": "OpenAIChoice",
"description": "A choice from an OpenAI-compatible chat completion response."
},
+ "OpenAIChoiceDelta": {
+ "type": "object",
+ "properties": {
+ "content": {
+ "type": "string",
+ "description": "(Optional) The content of the delta"
+ },
+ "refusal": {
+ "type": "string",
+ "description": "(Optional) The refusal of the delta"
+ },
+ "role": {
+ "type": "string",
+ "description": "(Optional) The role of the delta"
+ },
+ "tool_calls": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
+ },
+ "description": "(Optional) The tool calls of the delta"
+ }
+ },
+ "additionalProperties": false,
+ "title": "OpenAIChoiceDelta",
+ "description": "A delta from an OpenAI-compatible chat completion streaming response."
+ },
"OpenAIChoiceLogprobs": {
"type": "object",
"properties": {
@@ -9627,19 +9653,50 @@
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAITokenLogProb"
- }
+ },
+ "description": "(Optional) The log probabilities for the tokens in the message"
},
"refusal": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAITokenLogProb"
- }
+ },
+ "description": "(Optional) The log probabilities for the tokens in the refusal"
}
},
"additionalProperties": false,
"title": "OpenAIChoiceLogprobs",
"description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response."
},
+ "OpenAIChunkChoice": {
+ "type": "object",
+ "properties": {
+ "delta": {
+ "$ref": "#/components/schemas/OpenAIChoiceDelta",
+ "description": "The delta from the chunk"
+ },
+ "finish_reason": {
+ "type": "string",
+ "description": "The reason the model stopped generating"
+ },
+ "index": {
+ "type": "integer",
+ "description": "The index of the choice"
+ },
+ "logprobs": {
+ "$ref": "#/components/schemas/OpenAIChoiceLogprobs",
+ "description": "(Optional) The log probabilities for the tokens in the message"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "delta",
+ "finish_reason",
+ "index"
+ ],
+ "title": "OpenAIChunkChoice",
+ "description": "A chunk choice from an OpenAI-compatible chat completion streaming response."
+ },
"OpenAITokenLogProb": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 781fbc618..7a983ccc0 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6127,6 +6127,8 @@ components:
OpenAIChatCompletionToolCall:
type: object
properties:
+ index:
+ type: integer
id:
type: string
type:
@@ -6137,9 +6139,7 @@ components:
$ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction'
additionalProperties: false
required:
- - id
- type
- - function
title: OpenAIChatCompletionToolCall
OpenAIChatCompletionToolCallFunction:
type: object
@@ -6149,9 +6149,6 @@ components:
arguments:
type: string
additionalProperties: false
- required:
- - name
- - arguments
title: OpenAIChatCompletionToolCallFunction
OpenAIDeveloperMessageParam:
type: object
@@ -6550,7 +6547,7 @@ components:
choices:
type: array
items:
- $ref: '#/components/schemas/OpenAIChoice'
+ $ref: '#/components/schemas/OpenAIChunkChoice'
description: List of choices
object:
type: string
@@ -6587,8 +6584,11 @@ components:
description: The reason the model stopped generating
index:
type: integer
+ description: The index of the choice
logprobs:
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
+ description: >-
+ (Optional) The log probabilities for the tokens in the message
additionalProperties: false
required:
- message
@@ -6597,6 +6597,27 @@ components:
title: OpenAIChoice
description: >-
A choice from an OpenAI-compatible chat completion response.
+ OpenAIChoiceDelta:
+ type: object
+ properties:
+ content:
+ type: string
+ description: (Optional) The content of the delta
+ refusal:
+ type: string
+ description: (Optional) The refusal of the delta
+ role:
+ type: string
+ description: (Optional) The role of the delta
+ tool_calls:
+ type: array
+ items:
+ $ref: '#/components/schemas/OpenAIChatCompletionToolCall'
+ description: (Optional) The tool calls of the delta
+ additionalProperties: false
+ title: OpenAIChoiceDelta
+ description: >-
+ A delta from an OpenAI-compatible chat completion streaming response.
OpenAIChoiceLogprobs:
type: object
properties:
@@ -6604,15 +6625,43 @@ components:
type: array
items:
$ref: '#/components/schemas/OpenAITokenLogProb'
+ description: >-
+ (Optional) The log probabilities for the tokens in the message
refusal:
type: array
items:
$ref: '#/components/schemas/OpenAITokenLogProb'
+ description: >-
+ (Optional) The log probabilities for the tokens in the refusal
additionalProperties: false
title: OpenAIChoiceLogprobs
description: >-
The log probabilities for the tokens in the message from an OpenAI-compatible
chat completion response.
+ OpenAIChunkChoice:
+ type: object
+ properties:
+ delta:
+ $ref: '#/components/schemas/OpenAIChoiceDelta'
+ description: The delta from the chunk
+ finish_reason:
+ type: string
+ description: The reason the model stopped generating
+ index:
+ type: integer
+ description: The index of the choice
+ logprobs:
+ $ref: '#/components/schemas/OpenAIChoiceLogprobs'
+ description: >-
+ (Optional) The log probabilities for the tokens in the message
+ additionalProperties: false
+ required:
+ - delta
+ - finish_reason
+ - index
+ title: OpenAIChunkChoice
+ description: >-
+ A chunk choice from an OpenAI-compatible chat completion streaming response.
OpenAITokenLogProb:
type: object
properties:
diff --git a/docs/source/distributions/self_hosted_distro/groq.md b/docs/source/distributions/self_hosted_distro/groq.md
index 4f5a8a859..b18be1b2f 100644
--- a/docs/source/distributions/self_hosted_distro/groq.md
+++ b/docs/source/distributions/self_hosted_distro/groq.md
@@ -43,7 +43,9 @@ The following models are available by default:
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `groq/meta-llama/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
+- `groq/meta-llama/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
### Prerequisite: API Keys
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 1f3e64dd6..596efb136 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -503,15 +503,16 @@ class OpenAISystemMessageParam(BaseModel):
@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
- name: str
- arguments: str
+ name: Optional[str] = None
+ arguments: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
- id: str
+ index: Optional[int] = None
+ id: Optional[str] = None
type: Literal["function"] = "function"
- function: OpenAIChatCompletionToolCallFunction
+ function: Optional[OpenAIChatCompletionToolCallFunction] = None
@json_schema_type
@@ -645,22 +646,54 @@ class OpenAITokenLogProb(BaseModel):
class OpenAIChoiceLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
- :content: (Optional) The log probabilities for the tokens in the message
- :refusal: (Optional) The log probabilities for the tokens in the message
+ :param content: (Optional) The log probabilities for the tokens in the message
+ :param refusal: (Optional) The log probabilities for the tokens in the refusal
"""
content: Optional[List[OpenAITokenLogProb]] = None
refusal: Optional[List[OpenAITokenLogProb]] = None
+@json_schema_type
+class OpenAIChoiceDelta(BaseModel):
+ """A delta from an OpenAI-compatible chat completion streaming response.
+
+ :param content: (Optional) The content of the delta
+ :param refusal: (Optional) The refusal of the delta
+ :param role: (Optional) The role of the delta
+ :param tool_calls: (Optional) The tool calls of the delta
+ """
+
+ content: Optional[str] = None
+ refusal: Optional[str] = None
+ role: Optional[str] = None
+ tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+
+
+@json_schema_type
+class OpenAIChunkChoice(BaseModel):
+ """A chunk choice from an OpenAI-compatible chat completion streaming response.
+
+ :param delta: The delta from the chunk
+ :param finish_reason: The reason the model stopped generating
+ :param index: The index of the choice
+ :param logprobs: (Optional) The log probabilities for the tokens in the message
+ """
+
+ delta: OpenAIChoiceDelta
+ finish_reason: str
+ index: int
+ logprobs: Optional[OpenAIChoiceLogprobs] = None
+
+
@json_schema_type
class OpenAIChoice(BaseModel):
"""A choice from an OpenAI-compatible chat completion response.
:param message: The message from the model
:param finish_reason: The reason the model stopped generating
- :index: The index of the choice
- :logprobs: (Optional) The log probabilities for the tokens in the message
+ :param index: The index of the choice
+ :param logprobs: (Optional) The log probabilities for the tokens in the message
"""
message: OpenAIMessageParam
@@ -699,7 +732,7 @@ class OpenAIChatCompletionChunk(BaseModel):
"""
id: str
- choices: List[OpenAIChoice]
+ choices: List[OpenAIChunkChoice]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
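
For reference, the new chunk models above compose as follows. This is a minimal, illustrative sketch: the class and field names are exactly those defined in the `inference.py` hunks above, while the literal values (ids, timestamps, tool arguments) are made up.

```python
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoiceDelta,
    OpenAIChunkChoice,
)

# With index/id/function now optional, a partial tool-call delta is representable.
partial_tool_call = OpenAIChatCompletionToolCall(
    index=0,
    function=OpenAIChatCompletionToolCallFunction(name="get_weather", arguments='{"city": "'),
)

chunk = OpenAIChatCompletionChunk(
    id="chatcmpl-123",
    choices=[
        OpenAIChunkChoice(
            delta=OpenAIChoiceDelta(role="assistant", tool_calls=[partial_tool_call]),
            finish_reason="tool_calls",
            index=0,
        )
    ],
    created=1700000000,
    model="groq/llama-3.3-70b-versatile",
)
print(chunk.model_dump_json(indent=2))
```
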
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index c8789434f..f3f14e9af 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -4,8 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+from typing import Any, AsyncIterator, Dict, List, Optional, Union
+
+from openai import AsyncOpenAI
+
+from llama_stack.apis.inference.inference import (
+ OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
+ OpenAIChoiceDelta,
+ OpenAIChunkChoice,
+ OpenAIMessageParam,
+ OpenAIResponseFormatParam,
+ OpenAISystemMessageParam,
+)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import (
+ prepare_openai_completion_params,
+)
from .models import MODEL_ENTRIES
@@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="groq_api_key",
)
self.config = config
+ self._openai_client = None
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
+ if self._openai_client:
+ await self._openai_client.close()
+ self._openai_client = None
+
+ def _get_openai_client(self) -> AsyncOpenAI:
+ if not self._openai_client:
+ self._openai_client = AsyncOpenAI(
+ base_url=f"{self.config.url}/openai/v1",
+ api_key=self.config.api_key,
+ )
+ return self._openai_client
+
+ async def openai_chat_completion(
+ self,
+ model: str,
+ messages: List[OpenAIMessageParam],
+ frequency_penalty: Optional[float] = None,
+ function_call: Optional[Union[str, Dict[str, Any]]] = None,
+ functions: Optional[List[Dict[str, Any]]] = None,
+ logit_bias: Optional[Dict[str, float]] = None,
+ logprobs: Optional[bool] = None,
+ max_completion_tokens: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ presence_penalty: Optional[float] = None,
+ response_format: Optional[OpenAIResponseFormatParam] = None,
+ seed: Optional[int] = None,
+ stop: Optional[Union[str, List[str]]] = None,
+ stream: Optional[bool] = None,
+ stream_options: Optional[Dict[str, Any]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+ tools: Optional[List[Dict[str, Any]]] = None,
+ top_logprobs: Optional[int] = None,
+ top_p: Optional[float] = None,
+ user: Optional[str] = None,
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+ model_obj = await self.model_store.get_model(model)
+
+ # Groq does not support json_schema response format, so we need to convert it to json_object
+ if response_format and response_format.type == "json_schema":
+ response_format.type = "json_object"
+ schema = response_format.json_schema.get("schema", {})
+ response_format.json_schema = None
+ json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
+ if messages and messages[0].role == "system":
+ messages[0].content = messages[0].content + json_instructions
+ else:
+ messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
+
+ # Groq returns a 400 error if tools are provided but none are called
+ # So, set tool_choice to "required" to attempt to force a call
+ if tools and (not tool_choice or tool_choice == "auto"):
+ tool_choice = "required"
+
+ params = await prepare_openai_completion_params(
+ model=model_obj.provider_resource_id.replace("groq/", ""),
+ messages=messages,
+ frequency_penalty=frequency_penalty,
+ function_call=function_call,
+ functions=functions,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
+ max_completion_tokens=max_completion_tokens,
+ max_tokens=max_tokens,
+ n=n,
+ parallel_tool_calls=parallel_tool_calls,
+ presence_penalty=presence_penalty,
+ response_format=response_format,
+ seed=seed,
+ stop=stop,
+ stream=stream,
+ stream_options=stream_options,
+ temperature=temperature,
+ tool_choice=tool_choice,
+ tools=tools,
+ top_logprobs=top_logprobs,
+ top_p=top_p,
+ user=user,
+ )
+
+ # Groq does not support streaming requests that set response_format
+ fake_stream = False
+ if stream and response_format:
+ params["stream"] = False
+ fake_stream = True
+
+ response = await self._get_openai_client().chat.completions.create(**params)
+
+ if fake_stream:
+ chunk_choices = []
+ for choice in response.choices:
+ delta = OpenAIChoiceDelta(
+ content=choice.message.content,
+ role=choice.message.role,
+ tool_calls=choice.message.tool_calls,
+ )
+ chunk_choice = OpenAIChunkChoice(
+ delta=delta,
+ finish_reason=choice.finish_reason,
+ index=choice.index,
+ logprobs=None,
+ )
+ chunk_choices.append(chunk_choice)
+ chunk = OpenAIChatCompletionChunk(
+ id=response.id,
+ choices=chunk_choices,
+ object="chat.completion.chunk",
+ created=response.created,
+ model=response.model,
+ )
+
+ async def _fake_stream_generator():
+ yield chunk
+
+ return _fake_stream_generator()
+ else:
+ return response
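
Taken together, the adapter above (a) rewrites `json_schema` response formats to `json_object` plus a system-prompt instruction, and (b) serves `stream=True` requests that also set `response_format` by making a non-streaming call and replaying it as a single chunk. A hedged, client-side sketch of exercising that path through the Llama Stack OpenAI-compatible endpoint follows; the base URL matches `tests/verifications/conf/groq-llama-stack.yaml` added in this change, while the running local server, registered model, and placeholder API key are assumptions.

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Assumes a Llama Stack server on localhost:8321 with the groq provider configured.
    client = AsyncOpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

    # stream=True plus a json_schema response_format exercises the "fake stream" branch:
    # the adapter downgrades the format to json_object, appends the schema to the system
    # prompt, performs a non-streaming Groq call, and yields the result as one chunk.
    stream = await client.chat.completions.create(
        model="groq/llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": "Name a city and its country."}],
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "city",
                "schema": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        },
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```
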
diff --git a/llama_stack/providers/remote/inference/groq/models.py b/llama_stack/providers/remote/inference/groq/models.py
index d0c10ca62..0b4b81cfe 100644
--- a/llama_stack/providers/remote/inference/groq/models.py
+++ b/llama_stack/providers/remote/inference/groq/models.py
@@ -39,8 +39,16 @@ MODEL_ENTRIES = [
"groq/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
+ build_hf_repo_model_entry(
+ "groq/meta-llama/llama-4-scout-17b-16e-instruct",
+ CoreModelId.llama4_scout_17b_16e_instruct.value,
+ ),
build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
+ build_hf_repo_model_entry(
+ "groq/meta-llama/llama-4-maverick-17b-128e-instruct",
+ CoreModelId.llama4_maverick_17b_128e_instruct.value,
+ ),
]
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 95e8b767b..efe7031f5 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -298,7 +298,7 @@ class LiteLLMOpenAIMixin(
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
- return litellm.text_completion(**params)
+ return await litellm.atext_completion(**params)
async def openai_chat_completion(
self,
@@ -352,7 +352,7 @@ class LiteLLMOpenAIMixin(
top_p=top_p,
user=user,
)
- return litellm.completion(**params)
+ return await litellm.acompletion(**params)
async def batch_completion(
self,
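
The two changes above swap litellm's blocking entry points for their async counterparts, so the surrounding `async def` methods no longer block the event loop while waiting on the provider. A minimal sketch of the awaited call pattern (model id and prompt are placeholders):

```python
import asyncio

import litellm


async def ask(model: str, prompt: str) -> str:
    # litellm.acompletion is the async counterpart of litellm.completion.
    resp = await litellm.acompletion(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.choices[0].message.content


# Example (placeholder model id):
# print(asyncio.run(ask("groq/llama-3.3-70b-versatile", "Hello")))
```
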
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 2fcfa341e..d98261abb 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1354,14 +1354,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
i = 0
async for chunk in response:
event = chunk.event
- if event.stop_reason == StopReason.end_of_turn:
- finish_reason = "stop"
- elif event.stop_reason == StopReason.end_of_message:
- finish_reason = "eos"
- elif event.stop_reason == StopReason.out_of_tokens:
- finish_reason = "length"
- else:
- finish_reason = None
+ finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
if isinstance(event.delta, TextDelta):
text_delta = event.delta.text
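
The replaced if/elif chain is folded into `_convert_stop_reason_to_openai_finish_reason`, a helper that already exists elsewhere in `openai_compat.py`. Below is a self-contained sketch of the mapping the removed branches encoded; the `StopReason` enum here is a stand-in for the project's own type, and the real helper may handle additional cases differently.

```python
from enum import Enum
from typing import Optional


class StopReason(Enum):
    # Stand-in for llama_stack's StopReason; illustration only.
    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


def convert_stop_reason_to_openai_finish_reason(stop_reason: Optional[StopReason]) -> Optional[str]:
    # Mirrors the branches removed in the hunk above.
    return {
        StopReason.end_of_turn: "stop",
        StopReason.end_of_message: "eos",
        StopReason.out_of_tokens: "length",
    }.get(stop_reason)


assert convert_stop_reason_to_openai_finish_reason(StopReason.out_of_tokens) == "length"
assert convert_stop_reason_to_openai_finish_reason(None) is None
```
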
diff --git a/llama_stack/templates/dev/run.yaml b/llama_stack/templates/dev/run.yaml
index ea3b7252a..0dd056405 100644
--- a/llama_stack/templates/dev/run.yaml
+++ b/llama_stack/templates/dev/run.yaml
@@ -386,6 +386,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm
+- metadata: {}
+ model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ model_type: llm
+- metadata: {}
+ model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq
@@ -396,6 +406,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm
+- metadata: {}
+ model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ model_type: llm
+- metadata: {}
+ model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml
index f557e64fd..444452dcb 100644
--- a/llama_stack/templates/groq/run.yaml
+++ b/llama_stack/templates/groq/run.yaml
@@ -158,6 +158,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm
+- metadata: {}
+ model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ model_type: llm
+- metadata: {}
+ model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq
@@ -168,6 +178,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm
+- metadata: {}
+ model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ model_type: llm
+- metadata: {}
+ model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml
index b6c2ca98d..454ecba5b 100644
--- a/llama_stack/templates/verification/run.yaml
+++ b/llama_stack/templates/verification/run.yaml
@@ -474,6 +474,16 @@ models:
provider_id: groq-openai-compat
provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm
+- metadata: {}
+ model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ provider_id: groq-openai-compat
+ provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ model_type: llm
+- metadata: {}
+ model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
+ provider_id: groq-openai-compat
+ provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq-openai-compat
@@ -484,6 +494,16 @@ models:
provider_id: groq-openai-compat
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm
+- metadata: {}
+ model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ provider_id: groq-openai-compat
+ provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ model_type: llm
+- metadata: {}
+ model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
+ provider_id: groq-openai-compat
+ provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ model_type: llm
- metadata: {}
model_id: Meta-Llama-3.1-8B-Instruct
provider_id: sambanova-openai-compat
diff --git a/tests/verifications/conf/groq-llama-stack.yaml b/tests/verifications/conf/groq-llama-stack.yaml
new file mode 100644
index 000000000..fd5e9abec
--- /dev/null
+++ b/tests/verifications/conf/groq-llama-stack.yaml
@@ -0,0 +1,14 @@
+base_url: http://localhost:8321/v1/openai/v1
+api_key_var: GROQ_API_KEY
+models:
+- groq/llama-3.3-70b-versatile
+- groq/llama-4-scout-17b-16e-instruct
+- groq/llama-4-maverick-17b-128e-instruct
+model_display_names:
+ groq/llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
+ groq/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
+ groq/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
+test_exclusions:
+ groq/llama-3.3-70b-versatile:
+ - test_chat_non_streaming_image
+ - test_chat_streaming_image
diff --git a/tests/verifications/conf/groq.yaml b/tests/verifications/conf/groq.yaml
index 7871036dc..76b1244ae 100644
--- a/tests/verifications/conf/groq.yaml
+++ b/tests/verifications/conf/groq.yaml
@@ -2,12 +2,12 @@ base_url: https://api.groq.com/openai/v1
api_key_var: GROQ_API_KEY
models:
- llama-3.3-70b-versatile
-- llama-4-scout-17b-16e-instruct
-- llama-4-maverick-17b-128e-instruct
+- meta-llama/llama-4-scout-17b-16e-instruct
+- meta-llama/llama-4-maverick-17b-128e-instruct
model_display_names:
llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
- llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
- llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
+ meta-llama/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
+ meta-llama/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
test_exclusions:
llama-3.3-70b-versatile:
- test_chat_non_streaming_image
diff --git a/tests/verifications/conf/openai-llama-stack.yaml b/tests/verifications/conf/openai-llama-stack.yaml
index ee116dcf0..de35439ae 100644
--- a/tests/verifications/conf/openai-llama-stack.yaml
+++ b/tests/verifications/conf/openai-llama-stack.yaml
@@ -1,9 +1,9 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: OPENAI_API_KEY
models:
-- gpt-4o
-- gpt-4o-mini
+- openai/gpt-4o
+- openai/gpt-4o-mini
model_display_names:
- gpt-4o: gpt-4o
- gpt-4o-mini: gpt-4o-mini
+ openai/gpt-4o: gpt-4o
+ openai/gpt-4o-mini: gpt-4o-mini
test_exclusions: {}
diff --git a/tests/verifications/generate_report.py b/tests/verifications/generate_report.py
index c1eac8a33..b39c3fd19 100755
--- a/tests/verifications/generate_report.py
+++ b/tests/verifications/generate_report.py
@@ -75,6 +75,7 @@ PROVIDER_ORDER = [
"openai",
"together-llama-stack",
"fireworks-llama-stack",
+ "groq-llama-stack",
"openai-llama-stack",
]
diff --git a/tests/verifications/openai-api-verification-run.yaml b/tests/verifications/openai-api-verification-run.yaml
index 0e8b99e4f..71885d058 100644
--- a/tests/verifications/openai-api-verification-run.yaml
+++ b/tests/verifications/openai-api-verification-run.yaml
@@ -17,6 +17,11 @@ providers:
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY}
+ - provider_id: groq
+ provider_type: remote::groq
+ config:
+ url: https://api.groq.com
+ api_key: ${env.GROQ_API_KEY}
- provider_id: openai
provider_type: remote::openai
config:
@@ -98,6 +103,21 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
model_type: llm
+- metadata: {}
+ model_id: groq/llama-3.3-70b-versatile
+ provider_id: groq
+ provider_model_id: groq/llama-3.3-70b-versatile
+ model_type: llm
+- metadata: {}
+ model_id: groq/llama-4-scout-17b-16e-instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
+ model_type: llm
+- metadata: {}
+ model_id: groq/llama-4-maverick-17b-128e-instruct
+ provider_id: groq
+ provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
+ model_type: llm
- metadata: {}
model_id: openai/gpt-4o
provider_id: openai