# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncIterator, Dict, List, Optional, Union

from openai import AsyncOpenAI

from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChoiceDelta,
    OpenAIChunkChoice,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    OpenAISystemMessageParam,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
    prepare_openai_completion_params,
)

from .models import MODEL_ENTRIES


class GroqInferenceAdapter(LiteLLMOpenAIMixin):
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="groq_api_key",
        )
        self.config = config
        self._openai_client = None

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
        if self._openai_client:
            await self._openai_client.close()
            self._openai_client = None

    def _get_openai_client(self) -> AsyncOpenAI:
        if not self._openai_client:
            self._openai_client = AsyncOpenAI(
                base_url=f"{self.config.url}/openai/v1",
                api_key=self.config.api_key,
            )
        return self._openai_client

    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[OpenAIResponseFormatParam] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
        model_obj = await self.model_store.get_model(model)

        # Groq does not support json_schema response format, so we need to convert it to json_object
        if response_format and response_format.type == "json_schema":
            response_format.type = "json_object"
            schema = response_format.json_schema.get("schema", {})
            response_format.json_schema = None
            json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
            if messages and messages[0].role == "system":
                messages[0].content = messages[0].content + json_instructions
            else:
                messages.insert(0, OpenAISystemMessageParam(content=json_instructions))

        # Groq returns a 400 error if tools are provided but none are called
        # So, set tool_choice to "required" to attempt to force a call
        if tools and (not tool_choice or tool_choice == "auto"):
            tool_choice = "required"

        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id.replace("groq/", ""),
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )

        # Groq does not support streaming requests that set response_format
        fake_stream = False
        if stream and response_format:
            params["stream"] = False
            fake_stream = True

        response = await self._get_openai_client().chat.completions.create(**params)

        if fake_stream:
            chunk_choices = []
            for choice in response.choices:
                delta = OpenAIChoiceDelta(
                    content=choice.message.content,
                    role=choice.message.role,
                    tool_calls=choice.message.tool_calls,
                )
                chunk_choice = OpenAIChunkChoice(
                    delta=delta,
                    finish_reason=choice.finish_reason,
                    index=choice.index,
                    logprobs=None,
                )
                chunk_choices.append(chunk_choice)

            chunk = OpenAIChatCompletionChunk(
                id=response.id,
                choices=chunk_choices,
                object="chat.completion.chunk",
                created=response.created,
                model=response.model,
            )

            async def _fake_stream_generator():
                yield chunk

            return _fake_stream_generator()
        else:
            return response
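
# A minimal usage sketch (illustrative only, not part of the adapter): it shows
# how a caller might hit the fake-stream path above, where `stream=True`
# combined with a `response_format` is downgraded to a non-streaming request
# and the full response is re-wrapped as a single-chunk async iterator, so
# streaming callers keep working unchanged. The `adapter` variable, model id,
# and message/response-format values are hypothetical; in practice the adapter
# is constructed and wired to a model store by the llama-stack runtime rather
# than instantiated directly.
#
#     stream = await adapter.openai_chat_completion(
#         model="groq/llama-3.3-70b-versatile",
#         messages=[...],  # e.g. a single user message asking for JSON output
#         stream=True,
#         response_format=...,  # a json_schema format; converted to json_object above
#     )
#     async for chunk in stream:  # yields exactly one chunk in the fake-stream case
#         print(chunk.choices[0].delta.content)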