diff --git a/litellm/llms/groq/chat/transformation.py b/litellm/llms/groq/chat/transformation.py
index b0ee69bed2..37620ecfa5 100644
--- a/litellm/llms/groq/chat/transformation.py
+++ b/litellm/llms/groq/chat/transformation.py
@@ -2,10 +2,14 @@
 Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions`
 """
 
-from typing import List, Optional, Tuple, Union
+from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union
 
 from pydantic import BaseModel
 
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.openai.chat.gpt_transformation import (
+    OpenAIChatCompletionStreamingHandler,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -13,10 +17,15 @@ from litellm.types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
 )
+from litellm.types.utils import ModelResponse, ModelResponseStream
 
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 
 
+class GroqError(BaseLLMException):
+    pass
+
+
 class GroqChatConfig(OpenAIGPTConfig):
     frequency_penalty: Optional[int] = None
     function_call: Optional[Union[str, dict]] = None
@@ -155,3 +164,39 @@ class GroqChatConfig(OpenAIGPTConfig):
         return super().map_openai_params(
             non_default_params, optional_params, model, drop_params
         )
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return GroqChatCompletionStreamingHandler(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class GroqChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
+    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
+        try:
+            ## HANDLE ERROR IN CHUNK ##
+            if "error" in chunk:
+                error_chunk = chunk["error"]
+                raise GroqError(
+                    message="{}, Failed generation: {}".format(
+                        error_chunk["message"], error_chunk["failed_generation"]
+                    ),
+                    status_code=error_chunk["status_code"],
+                )
+
+            return super().chunk_parser(chunk)
+        except KeyError as e:
+            raise GroqError(
+                message=f"KeyError: {e}, Got unexpected response from Groq: {chunk}",
+                status_code=400,
+                headers={"Content-Type": "application/json"},
+            )
+        except Exception as e:
+            raise e
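
A minimal standalone sketch (not part of the patch) of the behavior the new `chunk_parser` adds: when Groq emits an in-stream error object, it is surfaced as a typed exception instead of being parsed as a normal streaming delta. `FakeGroqError`, `parse_chunk`, and the sample payload shape below are illustrative stand-ins, not litellm or Groq APIs.

```python
# Hypothetical stand-in for litellm's GroqError (a BaseLLMException subclass).
class FakeGroqError(Exception):
    def __init__(self, message: str, status_code: int):
        self.message = message
        self.status_code = status_code
        super().__init__(message)


def parse_chunk(chunk: dict) -> dict:
    # Mirrors the patched logic: if the chunk carries an "error" object,
    # raise immediately with Groq's message, failed generation, and status
    # code, rather than letting a malformed delta fail later downstream.
    if "error" in chunk:
        error_chunk = chunk["error"]
        raise FakeGroqError(
            message="{}, Failed generation: {}".format(
                error_chunk["message"], error_chunk["failed_generation"]
            ),
            status_code=error_chunk["status_code"],
        )
    # A well-formed chunk would be handed to the parent parser here.
    return chunk


# Example: a stream can carry an error object mid-stream, e.g. when a
# json_mode generation fails. (Payload shape assumed for illustration.)
error_chunk = {
    "error": {
        "message": "Failed to generate JSON",
        "failed_generation": "{invalid json",
        "status_code": 400,
    }
}

try:
    parse_chunk(error_chunk)
except FakeGroqError as e:
    print(e.status_code, e.message)
```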