feat(providers): Groq now uses LiteLLM openai-compat (#1303)

Groq has never supported raw completions anyhow. So this makes it easier
to switch it to LiteLLM. All our test suite passes.

I also updated all the openai-compat providers so they work with api
keys passed from headers. `provider_data`

## Test Plan

```bash
LLAMA_STACK_CONFIG=groq \
   pytest -s -v tests/client-sdk/inference/test_text_inference.py \
   --inference-model=groq/llama-3.3-70b-versatile --vision-inference-model=""
```

Also tested (openai, anthropic, gemini) providers. No regressions.
This commit is contained in:
Ashwin Bharambe 2025-02-27 13:16:50 -08:00 committed by GitHub
parent 564f0e5f93
commit 928a39d17b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 165 additions and 1004 deletions

View file

@ -146,6 +146,7 @@
"fastapi",
"fire",
"fireworks-ai",
"groq",
"httpx",
"litellm",
"matplotlib",

View file

@ -37,11 +37,11 @@ The following environment variables can be configured:
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct (llama3-8b-8192)`
- `meta-llama/Llama-3.1-8B-Instruct (llama-3.1-8b-instant)`
- `meta-llama/Llama-3-70B-Instruct (llama3-70b-8192)`
- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b-versatile)`
- `meta-llama/Llama-3.2-3B-Instruct (llama-3.2-3b-preview)`
- `meta-llama/Llama-3.1-8B-Instruct (groq/llama3-8b-8192)`
- `meta-llama/Llama-3.1-8B-Instruct (groq/llama-3.1-8b-instant)`
- `meta-llama/Llama-3-70B-Instruct (groq/llama3-70b-8192)`
- `meta-llama/Llama-3.3-70B-Instruct (groq/llama-3.3-70b-versatile)`
- `meta-llama/Llama-3.2-3B-Instruct (groq/llama-3.2-3b-preview)`
### Prerequisite: API Keys

View file

@ -157,16 +157,6 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["groq"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
@ -214,6 +204,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
),
),
remote_provider_spec(
@ -223,6 +214,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
),
),
remote_provider_spec(
@ -232,6 +224,17 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["groq"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
),
),
remote_provider_spec(

View file

@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config
async def initialize(self) -> None:
pass
await super().initialize()
async def shutdown(self) -> None:
pass
await super().shutdown()

View file

@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: Optional[str] = Field(
default=None,
description="API key for Anthropic models",
)
@json_schema_type
class AnthropicConfig(BaseModel):
api_key: Optional[str] = Field(

View file

@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: Optional[str] = Field(
default=None,
description="API key for Gemini models",
)
@json_schema_type
class GeminiConfig(BaseModel):
api_key: Optional[str] = Field(

View file

@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)
self.config = config
async def initialize(self) -> None:
pass
await super().initialize()
async def shutdown(self) -> None:
pass
await super().shutdown()

View file

@ -4,23 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.apis.inference import Inference
from .config import GroqConfig
class GroqProviderDataValidator(BaseModel):
groq_api_key: str
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter
if not isinstance(config, GroqConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = GroqInferenceAdapter(config)
return adapter

View file

@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class GroqProviderDataValidator(BaseModel):
groq_api_key: Optional[str] = Field(
default=None,
description="API key for Groq models",
)
@json_schema_type
class GroqConfig(BaseModel):
api_key: Optional[str] = Field(
@ -25,8 +32,8 @@ class GroqConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"url": "https://api.groq.com",
"api_key": "${env.GROQ_API_KEY}",
"api_key": api_key,
}

View file

@ -4,130 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from typing import AsyncIterator, List, Optional, Union
import groq
from groq import Groq
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InterleavedContent,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from .groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
convert_chat_completion_response_stream,
)
from .models import _MODEL_ENTRIES
from .models import MODEL_ENTRIES
class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
class GroqInferenceAdapter(LiteLLMOpenAIMixin):
_config: GroqConfig
def __init__(self, config: GroqConfig):
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
self._config = config
def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
# Groq doesn't support non-chat completion as of time of writing
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model_id = self.get_provider_model_id(model_id)
if model_id == "llama-3.2-3b-preview":
warnings.warn(
"Groq only contains a preview version for llama-3.2-3b-instruct. "
"Preview models aren't recommended for production use. "
"They can be discontinued on short notice."
"More details: https://console.groq.com/docs/models"
)
request = convert_chat_completion_request(
request=ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
self.config = config
try:
response = self._get_client().chat.completions.create(**request)
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
else:
raise e
async def initialize(self):
await super().initialize()
if stream:
return convert_chat_completion_response_stream(response)
else:
return convert_chat_completion_response(response)
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
def _get_client(self) -> Groq:
if self._config.api_key is not None:
return Groq(api_key=self._config.api_key)
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.groq_api_key:
raise ValueError(
'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
)
return Groq(api_key=provider_data.groq_api_key)
async def shutdown(self):
await super().shutdown()

View file

@ -1,245 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import AsyncGenerator, Literal
from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
ChatCompletionAssistantMessageParam,
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)
from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.shared.function_definition import FunctionDefinition
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
Message,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_tool_call,
get_sampling_strategy_options,
)
def convert_chat_completion_request(
request: ChatCompletionRequest,
) -> CompletionCreateParams:
"""
Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
Warns client if request contains unsupported features.
"""
if request.logprobs:
# Groq doesn't support logprobs at the time of writing
warnings.warn("logprobs are not supported yet")
if request.response_format:
# Groq's JSON mode is beta at the time of writing
warnings.warn("response_format is not supported yet")
if request.sampling_params.repetition_penalty != 1.0:
# groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
# seem to have different semantics
# frequency_penalty defaults to 0 is a float between -2.0 and 2.0
# repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")
if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
sampling_options = get_sampling_strategy_options(request.sampling_params)
return CompletionCreateParams(
model=request.model,
messages=[_convert_message(message) for message in request.messages],
logprobs=None,
frequency_penalty=None,
stream=request.stream,
max_tokens=request.sampling_params.max_tokens or None,
temperature=sampling_options.get("temperature", 1.0),
top_p=sampling_options.get("top_p", 1.0),
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None),
)
def _convert_message(message: Message) -> ChatCompletionMessageParam:
if message.role == "system":
return ChatCompletionSystemMessageParam(role="system", content=message.content)
elif message.role == "user":
return ChatCompletionUserMessageParam(role="user", content=message.content)
elif message.role == "assistant":
return ChatCompletionAssistantMessageParam(role="assistant", content=message.content)
else:
raise ValueError(f"Invalid message role: {message.role}")
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
# Groq requires a description for function tools
if tool_definition.description is None:
raise AssertionError("tool_definition.description is required")
tool_parameters = tool_definition.parameters or {}
return ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=tool_definition.tool_name,
description=tool_definition.description,
parameters={key: _convert_groq_tool_parameter(param) for key, param in tool_parameters.items()},
),
)
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
param = {
"type": tool_parameter.param_type,
}
if tool_parameter.description is not None:
param["description"] = tool_parameter.description
if tool_parameter.required is not None:
param["required"] = tool_parameter.required
if tool_parameter.default is not None:
param["default"] = tool_parameter.default
return param
def convert_chat_completion_response(
response: ChatCompletion,
) -> ChatCompletionResponse:
# groq only supports n=1 at time of writing, so there is only one choice
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
completion_message=CompletionMessage(
stop_reason=StopReason.end_of_message,
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
),
logprobs=None,
)
else:
# Otherwise, return tool calls as normal
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
stop_reason=StopReason.end_of_message,
# Content is not optional
content="",
),
logprobs=None,
)
else:
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)
def _map_finish_reason_to_stop_reason(
finish_reason: Literal["stop", "length", "tool_calls"],
) -> StopReason:
"""
Convert a Groq chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls"]
- stop -> model hit a natural stop point or a provided stop sequence
- length -> maximum number of tokens specified in the request was reached
- tool_calls -> model called a tool
"""
if finish_reason == "stop":
return StopReason.end_of_turn
elif finish_reason == "length":
return StopReason.out_of_tokens
elif finish_reason == "tool_calls":
return StopReason.end_of_message
else:
raise ValueError(f"Invalid finish reason: {finish_reason}")
async def convert_chat_completion_response_stream(
stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
event_type = ChatCompletionResponseEventType.start
for chunk in stream:
choice = chunk.choices[0]
if choice.finish_reason:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
)
)
elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
# We assume Groq produces fully formed tool calls for each chunk
tool_call = convert_tool_call(choice.delta.tool_calls[0])
if isinstance(tool_call, ToolCall):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
)
)
else:
# Otherwise it's an UnparseableToolCall - return the raw tool call
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=tool_call.model_dump_json(),
parse_status=ToolCallParseStatus.failed,
),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
)
)
event_type = ChatCompletionResponseEventType.progress

View file

@ -7,21 +7,21 @@
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_model_entry
_MODEL_ENTRIES = [
MODEL_ENTRIES = [
build_model_entry(
"llama3-8b-8192",
"groq/llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama-3.1-8b-instant",
"groq/llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama3-70b-8192",
"groq/llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_model_entry(
"llama-3.3-70b-versatile",
"groq/llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
@ -29,7 +29,7 @@ _MODEL_ENTRIES = [
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_model_entry(
"llama-3.2-3b-preview",
"groq/llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
]

View file

@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: Optional[str] = Field(
default=None,
description="API key for OpenAI models",
)
@json_schema_type
class OpenAIConfig(BaseModel):
api_key: Optional[str] = Field(

View file

@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",
)
self.config = config
async def initialize(self) -> None:
pass
await super().initialize()
async def shutdown(self) -> None:
pass
await super().shutdown()

View file

@ -1,575 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pytest
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
ChoiceDelta,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from groq.types.chat.chat_completion_chunk import (
Choice as StreamChoice,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from groq.types.shared.function_definition import FunctionDefinition
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
StopReason,
SystemMessage,
ToolCall,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy
from llama_stack.providers.remote.inference.groq.groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
convert_chat_completion_response_stream,
)
class TestConvertChatCompletionRequest:
def test_sets_model(self):
request = self._dummy_chat_completion_request()
request.model = "Llama-3.2-3B"
converted = convert_chat_completion_request(request)
assert converted["model"] == "Llama-3.2-3B"
def test_converts_user_message(self):
request = self._dummy_chat_completion_request()
request.messages = [UserMessage(content="Hello World")]
converted = convert_chat_completion_request(request)
assert converted["messages"] == [
{"role": "user", "content": "Hello World"},
]
def test_converts_system_message(self):
request = self._dummy_chat_completion_request()
request.messages = [SystemMessage(content="You are a helpful assistant.")]
converted = convert_chat_completion_request(request)
assert converted["messages"] == [
{"role": "system", "content": "You are a helpful assistant."},
]
def test_converts_completion_message(self):
request = self._dummy_chat_completion_request()
request.messages = [
UserMessage(content="Hello World"),
CompletionMessage(
content="Hello World! How can I help you today?",
stop_reason=StopReason.end_of_message,
),
]
converted = convert_chat_completion_request(request)
assert converted["messages"] == [
{"role": "user", "content": "Hello World"},
{"role": "assistant", "content": "Hello World! How can I help you today?"},
]
def test_does_not_include_logprobs(self):
request = self._dummy_chat_completion_request()
request.logprobs = True
with pytest.warns(Warning) as warnings:
converted = convert_chat_completion_request(request)
assert "logprobs are not supported yet" in warnings[0].message.args[0]
assert converted.get("logprobs") is None
def test_does_not_include_response_format(self):
request = self._dummy_chat_completion_request()
request.response_format = {
"type": "json_object",
"json_schema": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "number"},
},
},
}
with pytest.warns(Warning) as warnings:
converted = convert_chat_completion_request(request)
assert "response_format is not supported yet" in warnings[0].message.args[0]
assert converted.get("response_format") is None
def test_does_not_include_repetition_penalty(self):
request = self._dummy_chat_completion_request()
request.sampling_params.repetition_penalty = 1.5
with pytest.warns(Warning) as warnings:
converted = convert_chat_completion_request(request)
assert "repetition_penalty is not supported" in warnings[0].message.args[0]
assert converted.get("repetition_penalty") is None
assert converted.get("frequency_penalty") is None
def test_includes_stream(self):
request = self._dummy_chat_completion_request()
request.stream = True
converted = convert_chat_completion_request(request)
assert converted["stream"] is True
def test_if_max_tokens_is_0_then_it_is_not_included(self):
request = self._dummy_chat_completion_request()
# 0 is the default value for max_tokens
# So we assume that if it's 0, the user didn't set it
request.sampling_params.max_tokens = 0
converted = convert_chat_completion_request(request)
assert converted.get("max_tokens") is None
def test_includes_max_tokens_if_set(self):
request = self._dummy_chat_completion_request()
request.sampling_params.max_tokens = 100
converted = convert_chat_completion_request(request)
assert converted["max_tokens"] == 100
def _dummy_chat_completion_request(self):
return ChatCompletionRequest(
model="Llama-3.2-3B",
messages=[UserMessage(content="Hello World")],
)
def test_includes_stratgy(self):
request = self._dummy_chat_completion_request()
request.sampling_params.strategy = TopPSamplingStrategy(temperature=0.5, top_p=0.95)
converted = convert_chat_completion_request(request)
assert converted["temperature"] == 0.5
assert converted["top_p"] == 0.95
def test_includes_greedy_strategy(self):
request = self._dummy_chat_completion_request()
request.sampling_params.strategy = GreedySamplingStrategy()
converted = convert_chat_completion_request(request)
assert converted["temperature"] == 0.0
def test_includes_tool_choice(self):
request = self._dummy_chat_completion_request()
request.tool_config.tool_choice = ToolChoice.required
converted = convert_chat_completion_request(request)
assert converted["tool_choice"] == "required"
def test_includes_tools(self):
request = self._dummy_chat_completion_request()
request.tools = [
ToolDefinition(
tool_name="get_flight_info",
description="Get fight information between two destinations.",
parameters={
"origin": ToolParamDefinition(
param_type="string",
description="The origin airport code. E.g., AU",
required=True,
),
"destination": ToolParamDefinition(
param_type="string",
description="The destination airport code. E.g., 'LAX'",
required=True,
),
"passengers": ToolParamDefinition(
param_type="array",
description="The passengers",
required=False,
),
},
),
ToolDefinition(
tool_name="log",
description="Calulate the logarithm of a number",
parameters={
"number": ToolParamDefinition(
param_type="float",
description="The number to calculate the logarithm of",
required=True,
),
"base": ToolParamDefinition(
param_type="integer",
description="The base of the logarithm",
required=False,
default=10,
),
},
),
]
converted = convert_chat_completion_request(request)
assert converted["tools"] == [
{
"type": "function",
"function": FunctionDefinition(
name="get_flight_info",
description="Get fight information between two destinations.",
parameters={
"origin": {
"type": "string",
"description": "The origin airport code. E.g., AU",
"required": True,
},
"destination": {
"type": "string",
"description": "The destination airport code. E.g., 'LAX'",
"required": True,
},
"passengers": {
"type": "array",
"description": "The passengers",
"required": False,
},
},
),
},
{
"type": "function",
"function": FunctionDefinition(
name="log",
description="Calulate the logarithm of a number",
parameters={
"number": {
"type": "float",
"description": "The number to calculate the logarithm of",
"required": True,
},
"base": {
"type": "integer",
"description": "The base of the logarithm",
"required": False,
"default": 10,
},
},
),
},
]
class TestConvertNonStreamChatCompletionResponse:
def test_returns_response(self):
response = self._dummy_chat_completion_response()
response.choices[0].message.content = "Hello World"
converted = convert_chat_completion_response(response)
assert converted.completion_message.content == "Hello World"
def test_maps_stop_to_end_of_message(self):
response = self._dummy_chat_completion_response()
response.choices[0].finish_reason = "stop"
converted = convert_chat_completion_response(response)
assert converted.completion_message.stop_reason == StopReason.end_of_turn
def test_maps_length_to_end_of_message(self):
response = self._dummy_chat_completion_response()
response.choices[0].finish_reason = "length"
converted = convert_chat_completion_response(response)
assert converted.completion_message.stop_reason == StopReason.out_of_tokens
def test_maps_tool_call_to_end_of_message(self):
response = self._dummy_chat_completion_response_with_tool_call()
converted = convert_chat_completion_response(response)
assert converted.completion_message.stop_reason == StopReason.end_of_message
def test_converts_multiple_tool_calls(self):
response = self._dummy_chat_completion_response_with_tool_call()
response.choices[0].message.tool_calls = [
ChatCompletionMessageToolCall(
id="tool_call_id",
type="function",
function=Function(
name="get_flight_info",
arguments='{"origin": "AU", "destination": "LAX"}',
),
),
ChatCompletionMessageToolCall(
id="tool_call_id_2",
type="function",
function=Function(
name="log",
arguments='{"number": 10, "base": 2}',
),
),
]
converted = convert_chat_completion_response(response)
assert converted.completion_message.tool_calls == [
ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},
),
ToolCall(
call_id="tool_call_id_2",
tool_name="log",
arguments={"number": 10, "base": 2},
),
]
def test_converts_unparseable_tool_calls(self):
response = self._dummy_chat_completion_response_with_tool_call()
response.choices[0].message.tool_calls = [
ChatCompletionMessageToolCall(
id="tool_call_id",
type="function",
function=Function(
name="log",
arguments="(number=10, base=2)",
),
),
]
converted = convert_chat_completion_response(response)
assert (
converted.completion_message.content
== '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
)
def _dummy_chat_completion_response(self):
return ChatCompletion(
id="chatcmpl-123",
model="Llama-3.2-3B",
choices=[
Choice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello World"),
finish_reason="stop",
)
],
created=1729382400,
object="chat.completion",
)
def _dummy_chat_completion_response_with_tool_call(self):
return ChatCompletion(
id="chatcmpl-123",
model="Llama-3.2-3B",
choices=[
Choice(
index=0,
message=ChatCompletionMessage(
role="assistant",
tool_calls=[
ChatCompletionMessageToolCall(
id="tool_call_id",
type="function",
function=Function(
name="get_flight_info",
arguments='{"origin": "AU", "destination": "LAX"}',
),
)
],
),
finish_reason="tool_calls",
)
],
created=1729382400,
object="chat.completion",
)
class TestConvertStreamChatCompletionResponse:
@pytest.mark.asyncio
async def test_returns_stream(self):
def chat_completion_stream():
messages = ["Hello ", "World ", " !"]
for i, message in enumerate(messages):
chunk = self._dummy_chat_completion_chunk()
chunk.choices[0].delta.content = message
yield chunk
chunk = self._dummy_chat_completion_chunk()
chunk.choices[0].delta.content = None
chunk.choices[0].finish_reason = "stop"
yield chunk
stream = chat_completion_stream()
converted = convert_chat_completion_response_stream(stream)
iter = converted.__aiter__()
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert chunk.event.delta.text == "Hello "
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
assert chunk.event.delta.text == "World "
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
assert chunk.event.delta.text == " !"
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.complete
assert chunk.event.delta.text == ""
assert chunk.event.stop_reason == StopReason.end_of_turn
with pytest.raises(StopAsyncIteration):
await iter.__anext__()
@pytest.mark.asyncio
async def test_returns_tool_calls_stream(self):
def tool_call_stream():
tool_calls = [
ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},
),
ToolCall(
call_id="tool_call_id_2",
tool_name="log",
arguments={"number": 10, "base": 2},
),
]
for i, tool_call in enumerate(tool_calls):
chunk = self._dummy_chat_completion_chunk_with_tool_call()
chunk.choices[0].delta.tool_calls = [
ChoiceDeltaToolCall(
index=0,
type="function",
id=tool_call.call_id,
function=ChoiceDeltaToolCallFunction(
name=tool_call.tool_name,
arguments=json.dumps(tool_call.arguments),
),
),
]
yield chunk
chunk = self._dummy_chat_completion_chunk_with_tool_call()
chunk.choices[0].delta.content = None
chunk.choices[0].finish_reason = "stop"
yield chunk
stream = tool_call_stream()
converted = convert_chat_completion_response_stream(stream)
iter = converted.__aiter__()
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert chunk.event.delta.tool_call == ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},
)
@pytest.mark.asyncio
async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self):
def tool_call_stream():
chunk = self._dummy_chat_completion_chunk_with_tool_call()
chunk.choices[0].delta.tool_calls = [
ChoiceDeltaToolCall(
index=0,
type="function",
id="tool_call_id",
function=ChoiceDeltaToolCallFunction(
name="get_flight_info",
arguments="(origin=AU, destination=LAX)",
),
),
]
yield chunk
chunk = self._dummy_chat_completion_chunk_with_tool_call()
chunk.choices[0].delta.content = None
chunk.choices[0].finish_reason = "stop"
yield chunk
stream = tool_call_stream()
converted = convert_chat_completion_response_stream(stream)
iter = converted.__aiter__()
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert (
chunk.event.delta.content
== '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
)
assert chunk.event.delta.parse_status == ToolCallParseStatus.failed
def _dummy_chat_completion_chunk(self):
return ChatCompletionChunk(
id="chatcmpl-123",
model="Llama-3.2-3B",
choices=[
StreamChoice(
index=0,
delta=ChoiceDelta(role="assistant", content="Hello World"),
)
],
created=1729382400,
object="chat.completion.chunk",
x_groq=None,
)
def _dummy_chat_completion_chunk_with_tool_call(self):
return ChatCompletionChunk(
id="chatcmpl-123",
model="Llama-3.2-3B",
choices=[
StreamChoice(
index=0,
delta=ChoiceDelta(
role="assistant",
content="Hello World",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
type="function",
function=ChoiceDeltaToolCallFunction(
name="get_flight_info",
arguments='{"origin": "AU", "destination": "LAX"}',
),
)
],
),
)
],
created=1729382400,
object="chat.completion.chunk",
x_groq=None,
)

View file

@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -49,10 +50,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
NeedsRequestProviderData,
):
def __init__(self, model_entries) -> None:
self.model_entries = model_entries
def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
ModelRegistryHelper.__init__(self, model_entries)
self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field
async def initialize(self):
pass
async def shutdown(self):
pass
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
@ -144,8 +153,16 @@ class LiteLLMOpenAIMixin(
if request.tool_config.tool_choice:
input_dict["tool_choice"] = request.tool_config.tool_choice.value
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
return {
"model": request.model,
"api_key": api_key,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),

View file

@ -7,6 +7,7 @@ distribution_spec:
- remote::fireworks
- remote::anthropic
- remote::gemini
- remote::groq
- inline::sentence-transformers
vector_io:
- inline::sqlite-vec

View file

@ -24,6 +24,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.models import MODEL_ENTRIES as GEMINI_MODEL_ENTRIES
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -52,6 +54,11 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
GEMINI_MODEL_ENTRIES,
GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
),
(
"groq",
GROQ_MODEL_ENTRIES,
GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
),
]
inference_providers = []
default_models = []
@ -78,14 +85,9 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
def get_distribution_template() -> DistributionTemplate:
inference_providers, default_models = get_inference_providers()
providers = {
"inference": [
"remote::openai",
"remote::fireworks",
"remote::anthropic",
"remote::gemini",
"inline::sentence-transformers",
],
"inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
@ -136,7 +138,6 @@ def get_distribution_template() -> DistributionTemplate:
"embedding_dimension": 384,
},
)
inference_providers, default_models = get_inference_providers()
return DistributionTemplate(
name=name,

View file

@ -29,6 +29,11 @@ providers:
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
@ -241,6 +246,31 @@ models:
provider_id: gemini
provider_model_id: gemini/text-embedding-004
model_type: embedding
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: groq
provider_model_id: groq/llama3-8b-8192
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.1-8b-instant
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-70B-Instruct
provider_id: groq
provider_model_id: groq/llama3-70b-8192
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.2-3b-preview
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -16,9 +16,8 @@ from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.groq.models import _MODEL_ENTRIES
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -52,11 +51,6 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
@ -69,11 +63,13 @@ def get_distribution_template() -> DistributionTemplate:
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id=name,
model_type=m.model_type,
metadata=m.metadata,
)
for m in _MODEL_ENTRIES
for m in MODEL_ENTRIES
]
default_tool_groups = [

View file

@ -93,27 +93,27 @@ models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: groq
provider_model_id: llama3-8b-8192
provider_model_id: groq/llama3-8b-8192
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: groq
provider_model_id: llama-3.1-8b-instant
provider_model_id: groq/llama-3.1-8b-instant
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-70B-Instruct
provider_id: groq
provider_model_id: llama3-70b-8192
provider_model_id: groq/llama3-70b-8192
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: llama-3.3-70b-versatile
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: groq
provider_model_id: llama-3.2-3b-preview
provider_model_id: groq/llama-3.2-3b-preview
model_type: llm
- metadata:
embedding_dimension: 384

View file

@ -117,7 +117,9 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
model_ids = [m.identifier for m in client.models.list()]
model_ids = set(m.identifier for m in client.models.list())
model_ids.update(m.provider_resource_id for m in client.models.list())
if text_model_id and text_model_id not in model_ids:
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
if vision_model_id and vision_model_id not in model_ids:

View file

@ -18,7 +18,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")