Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 04:04:14 +00:00
chore: update the groq inference impl to use openai-python for openai-compat functions (#3348)
# What does this PR do?

Update the Groq inference provider to use OpenAIMixin for openai-compat endpoints.

Changes on api.groq.com (a usage sketch follows after the test plan):
- json_schema is now supported for specific models, see https://console.groq.com/docs/structured-outputs#supported-models
- response_format with streaming is now supported for models that support response_format
- Groq no longer returns a 400 error if tools are provided and tool_choice is not "required"

## Test Plan

```
$ GROQ_API_KEY=... uv run llama stack build --image-type venv --providers inference=remote::groq --run
...
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model groq/llama-3.3-70b-versatile tests/integration/inference/test_openai_completion.py -k 'not store'
...
SKIPPED [3] tests/integration/inference/test_openai_completion.py:44: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support OpenAI completions.
SKIPPED [3] tests/integration/inference/test_openai_completion.py:94: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support vllm extra_body parameters.
SKIPPED [4] tests/integration/inference/test_openai_completion.py:73: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support n param.
SKIPPED [1] tests/integration/inference/test_openai_completion.py:100: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support chat completion calls with base64 encoded files.
======================= 8 passed, 11 skipped, 8 deselected, 2 warnings in 5.13s ========================
```

---------

Co-authored-by: raghotham <rsm@meta.com>
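For context on the first bullet, here is a minimal sketch of exercising Groq's OpenAI-compatible endpoint directly with openai-python (the client this PR moves the provider onto). The model id is taken from the test plan with the llama-stack provider prefix dropped, and the schema is invented for illustration; whether a given model accepts json_schema is governed by the Groq docs linked above.

```python
# Illustrative sketch only: calling Groq's OpenAI-compatible endpoint with
# openai-python and the newly supported json_schema response_format.
# Model id and schema are assumptions for demonstration, not part of this PR.
import asyncio
import os

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        base_url="https://api.groq.com/openai/v1",
        api_key=os.environ["GROQ_API_KEY"],
    )
    response = await client.chat.completions.create(
        # Model from the test plan; structured-output support is per Groq's docs.
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": "Name a city and its country."}],
        # json_schema response_format, now accepted by api.groq.com for supported models.
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "city_info",
                "schema": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "country": {"type": "string"},
                    },
                    "required": ["city", "country"],
                },
            },
        },
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```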
This commit is contained in:
parent ecd9d8dc1a
commit d23607483f

3 changed files with 10 additions and 134 deletions
```diff
@@ -248,7 +248,7 @@ Available Models:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="groq",
-            pip_packages=["litellm"],
+            pip_packages=["litellm", "openai"],
             module="llama_stack.providers.remote.inference.groq",
             config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
             provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
```
```diff
@@ -4,30 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-from typing import Any
-
-from openai import AsyncOpenAI
-
-from llama_stack.apis.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChoiceDelta,
-    OpenAIChunkChoice,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    OpenAISystemMessageParam,
-)
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
 
-class GroqInferenceAdapter(LiteLLMOpenAIMixin):
+class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     _config: GroqConfig
 
     def __init__(self, config: GroqConfig):
```
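The base-class order in the new class line matters: listing OpenAIMixin before LiteLLMOpenAIMixin makes Python's MRO resolve the openai-compat methods from OpenAIMixin first. Here is a standalone sketch of that resolution behavior, using illustrative stand-in classes rather than the real mixins:

```python
# Standalone illustration of Python's method resolution order (MRO), which the
# new base-class ordering relies on. The classes below are illustrative
# stand-ins, not the real llama_stack mixins.
class LiteLLMLike:
    def openai_chat_completion(self) -> str:
        return "litellm-based implementation"


class OpenAIMixinLike:
    def openai_chat_completion(self) -> str:
        return "openai-python-based implementation"


class Adapter(OpenAIMixinLike, LiteLLMLike):
    pass


if __name__ == "__main__":
    # Lookup follows the MRO left to right, so the OpenAI-style mixin wins.
    print(Adapter().openai_chat_completion())  # openai-python-based implementation
    print([cls.__name__ for cls in Adapter.__mro__])
    # ['Adapter', 'OpenAIMixinLike', 'LiteLLMLike', 'object']
```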
```diff
@@ -40,122 +25,14 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
         )
         self.config = config
 
+    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        return f"{self.config.url}/openai/v1"
+
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(
-            base_url=f"{self.config.url}/openai/v1",
-            api_key=self.get_api_key(),
-        )
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Groq does not support json_schema response format, so we need to convert it to json_object
-        if response_format and response_format.type == "json_schema":
-            response_format.type = "json_object"
-            schema = response_format.json_schema.get("schema", {})
-            response_format.json_schema = None
-            json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
-            if messages and messages[0].role == "system":
-                messages[0].content = messages[0].content + json_instructions
-            else:
-                messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
-
-        # Groq returns a 400 error if tools are provided but none are called
-        # So, set tool_choice to "required" to attempt to force a call
-        if tools and (not tool_choice or tool_choice == "auto"):
-            tool_choice = "required"
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        # Groq does not support streaming requests that set response_format
-        fake_stream = False
-        if stream and response_format:
-            params["stream"] = False
-            fake_stream = True
-
-        response = await self._get_openai_client().chat.completions.create(**params)
-
-        if fake_stream:
-            chunk_choices = []
-            for choice in response.choices:
-                delta = OpenAIChoiceDelta(
-                    content=choice.message.content,
-                    role=choice.message.role,
-                    tool_calls=choice.message.tool_calls,
-                )
-                chunk_choice = OpenAIChunkChoice(
-                    delta=delta,
-                    finish_reason=choice.finish_reason,
-                    index=choice.index,
-                    logprobs=None,
-                )
-                chunk_choices.append(chunk_choice)
-            chunk = OpenAIChatCompletionChunk(
-                id=response.id,
-                choices=chunk_choices,
-                object="chat.completion.chunk",
-                created=response.created,
-                model=response.model,
-            )
-
-            async def _fake_stream_generator():
-                yield chunk
-
-            return _fake_stream_generator()
-        else:
-            return response
```
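With `_get_openai_client` and the `openai_chat_completion` override gone, the adapter now only supplies `get_api_key` (delegated to LiteLLMOpenAIMixin) and `get_base_url`. Below is a minimal sketch of the pattern this presumes an OpenAIMixin-style base provides; it is an assumption for illustration, not the actual `OpenAIMixin` implementation in llama_stack.

```python
# Assumed simplification of the OpenAIMixin pattern, for illustration only;
# the real llama_stack OpenAIMixin may differ. The concrete adapter supplies
# get_api_key() and get_base_url(), and the mixin builds an AsyncOpenAI
# client from them.
from openai import AsyncOpenAI


class SketchOpenAIMixin:
    def get_api_key(self) -> str:
        raise NotImplementedError  # provided by the concrete adapter

    def get_base_url(self) -> str:
        raise NotImplementedError  # provided by the concrete adapter

    @property
    def client(self) -> AsyncOpenAI:
        # Building the client from the two hooks means a per-request API key
        # (e.g. from provider data) is picked up without adapter-specific code.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())


class SketchGroqAdapter(SketchOpenAIMixin):
    # Mirrors the two pieces the real adapter now defines.
    def __init__(self, url: str, api_key: str) -> None:
        self._url = url
        self._api_key = api_key

    def get_api_key(self) -> str:
        return self._api_key

    def get_base_url(self) -> str:
        return f"{self._url}/openai/v1"
```

Under that assumption, the test change below, which asserts on `inference_adapter.client.api_key`, checks that a per-request API key supplied via provider data still reaches the constructed client.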
```diff
@@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching():
         with request_provider_data_context(
             {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
         ):
-            openai_client = inference_adapter._get_openai_client()
-            assert openai_client.api_key == api_key
+            assert inference_adapter.client.api_key == api_key
 
 
 def test_openai_provider_openai_client_caching():
```