# What does this PR do?

* Given that our API packages use `import *` in `__init__.py`, we don't need to do `from llama_stack.apis.models.models import ...` but simply `from llama_stack.apis.models import ...`. The decision to use `import *` is debatable and should probably be revisited at some point.
* Remove the unneeded Ruff F401 rule.
* Consolidate the Ruff F403 rule in the pyproject.

Signed-off-by: Sébastien Han <seb@redhat.com>
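
For illustration, a minimal before/after sketch of the import change (assuming `llama_stack/apis/models/__init__.py` does `from .models import *`; `Model` stands in for any re-exported name):

```python
# Before: reaching into the submodule directly
from llama_stack.apis.models.models import Model

# After: the package-level import is enough, because __init__.py
# re-exports everything via `from .models import *`
from llama_stack.apis.models import Model
```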
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from typing import Any

from openai import AsyncOpenAI

from llama_stack.apis.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChoiceDelta,
    OpenAIChunkChoice,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    OpenAISystemMessageParam,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
    prepare_openai_completion_params,
)

from .models import MODEL_ENTRIES


class GroqInferenceAdapter(LiteLLMOpenAIMixin):
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="groq_api_key",
        )
        self.config = config
        self._openai_client = None

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
        if self._openai_client:
            await self._openai_client.close()
            self._openai_client = None

    def _get_openai_client(self) -> AsyncOpenAI:
        if not self._openai_client:
            self._openai_client = AsyncOpenAI(
                base_url=f"{self.config.url}/openai/v1",
                api_key=self.config.api_key,
            )
        return self._openai_client

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        model_obj = await self.model_store.get_model(model)

        # Groq does not support json_schema response format, so we need to convert it to json_object
        if response_format and response_format.type == "json_schema":
            response_format.type = "json_object"
            schema = response_format.json_schema.get("schema", {})
            response_format.json_schema = None
            json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
            if messages and messages[0].role == "system":
                messages[0].content = messages[0].content + json_instructions
            else:
                messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
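
        # For example, a response_format of type "json_schema" carrying
        # {"schema": {...}} becomes {"type": "json_object"}, and the schema
        # itself is appended to the system prompt so the model still sees it.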

        # Groq returns a 400 error if tools are provided but none are called
        # So, set tool_choice to "required" to attempt to force a call
        if tools and (not tool_choice or tool_choice == "auto"):
            tool_choice = "required"

        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id.replace("groq/", ""),
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )

        # Groq does not support streaming requests that set response_format
        fake_stream = False
        if stream and response_format:
            params["stream"] = False
            fake_stream = True

        response = await self._get_openai_client().chat.completions.create(**params)

        if fake_stream:
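            # The request above was forced to stream=False, so replay the fully
            # materialized response as a single chunk from an async generator,
            # preserving the streaming interface the caller asked for.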
            chunk_choices = []
            for choice in response.choices:
                delta = OpenAIChoiceDelta(
                    content=choice.message.content,
                    role=choice.message.role,
                    tool_calls=choice.message.tool_calls,
                )
                chunk_choice = OpenAIChunkChoice(
                    delta=delta,
                    finish_reason=choice.finish_reason,
                    index=choice.index,
                    logprobs=None,
                )
                chunk_choices.append(chunk_choice)
            chunk = OpenAIChatCompletionChunk(
                id=response.id,
                choices=chunk_choices,
                object="chat.completion.chunk",
                created=response.created,
                model=response.model,
            )

            async def _fake_stream_generator():
                yield chunk

            return _fake_stream_generator()
        else:
            return response
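
For context, a rough sketch of driving the adapter directly. This is hypothetical: `OpenAIUserMessageParam`, the `GroqConfig` field values, and the model id are assumptions, and `openai_chat_completion` additionally requires a model store to be attached (normally wired up by the stack).

```python
import asyncio

from llama_stack.apis.inference import OpenAIUserMessageParam  # assumed export

# GroqInferenceAdapter and GroqConfig as defined/imported above


async def main() -> None:
    # Field names taken from the adapter code above; values are placeholders.
    config = GroqConfig(api_key="gsk-...", url="https://api.groq.com")
    adapter = GroqInferenceAdapter(config)
    await adapter.initialize()
    try:
        # Assumes `model` resolves through the adapter's model store.
        response = await adapter.openai_chat_completion(
            model="groq/llama-3.3-70b-versatile",  # hypothetical model id
            messages=[OpenAIUserMessageParam(content="Say hello.")],
        )
        print(response)
    finally:
        await adapter.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```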