Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 10:14:26 +00:00
fix(groq/chat/transformation.py): handle max_retries
parent 1fc9de1928
commit e0bc837957

1 changed file with 9 additions and 46 deletions
litellm/llms/groq/chat/transformation.py

@@ -2,14 +2,10 @@
 Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions`
 """
 
-from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 from pydantic import BaseModel
 
-from litellm.llms.base_llm.chat.transformation import BaseLLMException
-from litellm.llms.openai.chat.gpt_transformation import (
-    OpenAIChatCompletionStreamingHandler,
-)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -17,15 +13,10 @@ from litellm.types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
 )
-from litellm.types.utils import ModelResponse, ModelResponseStream
 
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 
 
-class GroqError(BaseLLMException):
-    pass
-
-
 class GroqChatConfig(OpenAIGPTConfig):
     frequency_penalty: Optional[int] = None
     function_call: Optional[Union[str, dict]] = None
@@ -66,6 +57,14 @@ class GroqChatConfig(OpenAIGPTConfig):
     def get_config(cls):
         return super().get_config()
 
+    def get_supported_openai_params(self, model: str) -> list:
+        base_params = super().get_supported_openai_params(model)
+        try:
+            base_params.remove("max_retries")
+        except ValueError:
+            pass
+        return base_params
+
     def _transform_messages(self, messages: List[AllMessageValues], model: str) -> List:
         for idx, message in enumerate(messages):
             """
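The added `get_supported_openai_params` override above is the actual max_retries fix: `max_retries` is a client-side retry setting from the OpenAI SDK, not a field Groq's `/v1/chat/completions` endpoint accepts, so it is stripped from the parameter list inherited from `OpenAIGPTConfig`. A minimal sketch of the resulting behavior, assuming the file path shown above; the model name is an illustrative placeholder, not taken from the diff:

    from litellm.llms.groq.chat.transformation import GroqChatConfig

    config = GroqChatConfig()
    params = config.get_supported_openai_params(model="groq/llama-3.1-8b-instant")

    # "max_retries" is removed from the inherited OpenAI param list, so litellm
    # will not forward it in the request body to Groq.
    assert "max_retries" not in params
    assert "temperature" in params  # other supported OpenAI params remain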
@@ -164,39 +163,3 @@ class GroqChatConfig(OpenAIGPTConfig):
         return super().map_openai_params(
             non_default_params, optional_params, model, drop_params
         )
-
-    def get_model_response_iterator(
-        self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
-        sync_stream: bool,
-        json_mode: Optional[bool] = False,
-    ) -> Any:
-        return GroqChatCompletionStreamingHandler(
-            streaming_response=streaming_response,
-            sync_stream=sync_stream,
-            json_mode=json_mode,
-        )
-
-
-class GroqChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
-    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
-        try:
-            ## HANDLE ERROR IN CHUNK ##
-            if "error" in chunk:
-                error_chunk = chunk["error"]
-                raise GroqError(
-                    message="{}, Failed generation: {}".format(
-                        error_chunk["message"], error_chunk["failed_generation"]
-                    ),
-                    status_code=error_chunk["status_code"],
-                )
-
-            return super().chunk_parser(chunk)
-        except KeyError as e:
-            raise GroqError(
-                message=f"KeyError: {e}, Got unexpected response from Groq: {chunk}",
-                status_code=400,
-                headers={"Content-Type": "application/json"},
-            )
-        except Exception as e:
-            raise e
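With `get_model_response_iterator` and `GroqChatCompletionStreamingHandler` deleted, Groq streaming falls back to the default OpenAI-compatible chunk parsing inherited from `OpenAIGPTConfig`, which is why the `GroqError` class and the related imports could be dropped earlier in the diff. A hedged usage sketch of a streaming call after this commit; the model name and prompt are illustrative only:

    import litellm

    # max_retries is now handled as a client-side retry setting for Groq calls;
    # the new get_supported_openai_params override keeps it out of the request body.
    response = litellm.completion(
        model="groq/llama-3.1-8b-instant",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
        max_retries=2,
    )
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")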