From d23607483fc8ca63ea15db5a95f672aec249d2e1 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 18:36:27 -0400 Subject: [PATCH] chore: update the groq inference impl to use openai-python for openai-compat functions (#3348) # What does this PR do? update Groq inference provider to use OpenAIMixin for openai-compat endpoints changes on api.groq.com - - json_schema is now supported for specific models, see https://console.groq.com/docs/structured-outputs#supported-models - response_format with streaming is now supported for models that support response_format - groq no longer returns a 400 error if tools are provided and tool_choice is not "required" ## Test Plan ``` $ GROQ_API_KEY=... uv run llama stack build --image-type venv --providers inference=remote::groq --run ... $ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model groq/llama-3.3-70b-versatile tests/integration/inference/test_openai_completion.py -k 'not store' ... SKIPPED [3] tests/integration/inference/test_openai_completion.py:44: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support OpenAI completions. SKIPPED [3] tests/integration/inference/test_openai_completion.py:94: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support vllm extra_body parameters. SKIPPED [4] tests/integration/inference/test_openai_completion.py:73: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support n param. SKIPPED [1] tests/integration/inference/test_openai_completion.py:100: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support chat completion calls with base64 encoded files. ======================= 8 passed, 11 skipped, 8 deselected, 2 warnings in 5.13s ======================== ``` --------- Co-authored-by: raghotham --- llama_stack/providers/registry/inference.py | 2 +- .../providers/remote/inference/groq/groq.py | 139 +----------------- .../test_inference_client_caching.py | 3 +- 3 files changed, 10 insertions(+), 134 deletions(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 7a95fd089..4176f85a6 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -248,7 +248,7 @@ Available Models: api=Api.inference, adapter=AdapterSpec( adapter_type="groq", - pip_packages=["litellm"], + pip_packages=["litellm", "openai"], module="llama_stack.providers.remote.inference.groq", config_class="llama_stack.providers.remote.inference.groq.GroqConfig", provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index fd7212de4..888953af0 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -4,30 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import AsyncIterator -from typing import Any -from openai import AsyncOpenAI - -from llama_stack.apis.inference import ( - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAIChoiceDelta, - OpenAIChunkChoice, - OpenAIMessageParam, - OpenAIResponseFormatParam, - OpenAISystemMessageParam, -) from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from llama_stack.providers.utils.inference.openai_compat import ( - prepare_openai_completion_params, -) +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES -class GroqInferenceAdapter(LiteLLMOpenAIMixin): +class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): _config: GroqConfig def __init__(self, config: GroqConfig): @@ -40,122 +25,14 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin): ) self.config = config + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: + return f"{self.config.url}/openai/v1" + async def initialize(self): await super().initialize() async def shutdown(self): await super().shutdown() - - def _get_openai_client(self) -> AsyncOpenAI: - return AsyncOpenAI( - base_url=f"{self.config.url}/openai/v1", - api_key=self.get_api_key(), - ) - - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - model_obj = await self.model_store.get_model(model) - - # Groq does not support json_schema response format, so we need to convert it to json_object - if response_format and response_format.type == "json_schema": - response_format.type = "json_object" - schema = response_format.json_schema.get("schema", {}) - response_format.json_schema = None - json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}" - if messages and messages[0].role == "system": - messages[0].content = messages[0].content + json_instructions - else: - messages.insert(0, OpenAISystemMessageParam(content=json_instructions)) - - # Groq returns a 400 error if tools are provided but none are called - # So, set tool_choice to "required" to attempt to force a call - if tools and (not tool_choice or tool_choice == "auto"): - tool_choice = "required" - - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) - - # Groq does not support streaming requests that set response_format - fake_stream = False - if stream and response_format: - params["stream"] = False - fake_stream = True - - response = await self._get_openai_client().chat.completions.create(**params) - - if fake_stream: - chunk_choices = [] - for choice in response.choices: - delta = OpenAIChoiceDelta( - content=choice.message.content, - role=choice.message.role, - tool_calls=choice.message.tool_calls, - ) - chunk_choice = OpenAIChunkChoice( - delta=delta, - finish_reason=choice.finish_reason, - index=choice.index, - logprobs=None, - ) - chunk_choices.append(chunk_choice) - chunk = OpenAIChatCompletionChunk( - id=response.id, - choices=chunk_choices, - object="chat.completion.chunk", - created=response.created, - model=response.model, - ) - - async def _fake_stream_generator(): - yield chunk - - return _fake_stream_generator() - else: - return response diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index b371cf907..f4b3201e9 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching(): with request_provider_data_context( {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} ): - openai_client = inference_adapter._get_openai_client() - assert openai_client.api_key == api_key + assert inference_adapter.client.api_key == api_key def test_openai_provider_openai_client_caching():