fix(groq/chat/transformation.py): handle max_retries

This commit is contained in:
Krrish Dholakia 2025-04-15 21:58:17 -07:00
parent 1fc9de1928
commit e0bc837957

View file

@ -2,14 +2,10 @@
Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions`
"""
from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import (
OpenAIChatCompletionStreamingHandler,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
@ -17,15 +13,10 @@ from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ModelResponse, ModelResponseStream
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
class GroqError(BaseLLMException):
pass
class GroqChatConfig(OpenAIGPTConfig):
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
@ -66,6 +57,14 @@ class GroqChatConfig(OpenAIGPTConfig):
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
base_params = super().get_supported_openai_params(model)
try:
base_params.remove("max_retries")
except ValueError:
pass
return base_params
def _transform_messages(self, messages: List[AllMessageValues], model: str) -> List:
for idx, message in enumerate(messages):
"""
@ -164,39 +163,3 @@ class GroqChatConfig(OpenAIGPTConfig):
return super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
return GroqChatCompletionStreamingHandler(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class GroqChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
try:
## HANDLE ERROR IN CHUNK ##
if "error" in chunk:
error_chunk = chunk["error"]
raise GroqError(
message="{}, Failed generation: {}".format(
error_chunk["message"], error_chunk["failed_generation"]
),
status_code=error_chunk["status_code"],
)
return super().chunk_parser(chunk)
except KeyError as e:
raise GroqError(
message=f"KeyError: {e}, Got unexpected response from Groq: {chunk}",
status_code=400,
headers={"Content-Type": "application/json"},
)
except Exception as e:
raise e