mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Co-authored-by: Joey Feldberg <joeyfeldberg@users.noreply.github.com> Co-authored-by: Joey Feldberg <12495578+joeyfeldberg@users.noreply.github.com>
66 lines
2 KiB
Python
66 lines
2 KiB
Python
from typing import List, Optional, Union
|
|
|
|
from httpx._models import Headers
|
|
|
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
|
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
|
|
|
|
|
class MaritalkError(BaseLLMException):
|
|
def __init__(
|
|
self,
|
|
status_code: int,
|
|
message: str,
|
|
headers: Optional[Union[dict, Headers]] = None,
|
|
):
|
|
super().__init__(status_code=status_code, message=message, headers=headers)
|
|
|
|
|
|
class MaritalkConfig(OpenAIGPTConfig):
|
|
|
|
def __init__(
|
|
self,
|
|
frequency_penalty: Optional[float] = None,
|
|
presence_penalty: Optional[float] = None,
|
|
top_p: Optional[float] = None,
|
|
top_k: Optional[int] = None,
|
|
temperature: Optional[float] = None,
|
|
max_tokens: Optional[int] = None,
|
|
n: Optional[int] = None,
|
|
stop: Optional[List[str]] = None,
|
|
stream: Optional[bool] = None,
|
|
stream_options: Optional[dict] = None,
|
|
tools: Optional[List[dict]] = None,
|
|
tool_choice: Optional[Union[str, dict]] = None,
|
|
) -> None:
|
|
locals_ = locals().copy()
|
|
for key, value in locals_.items():
|
|
if key != "self" and value is not None:
|
|
setattr(self.__class__, key, value)
|
|
|
|
@classmethod
|
|
def get_config(cls):
|
|
return super().get_config()
|
|
|
|
def get_supported_openai_params(self, model: str) -> List:
|
|
return [
|
|
"frequency_penalty",
|
|
"presence_penalty",
|
|
"top_p",
|
|
"top_k",
|
|
"temperature",
|
|
"max_tokens",
|
|
"n",
|
|
"stop",
|
|
"stream",
|
|
"stream_options",
|
|
"tools",
|
|
"tool_choice",
|
|
]
|
|
|
|
def get_error_class(
|
|
self, error_message: str, status_code: int, headers: Union[dict, Headers]
|
|
) -> BaseLLMException:
|
|
return MaritalkError(
|
|
status_code=status_code, message=error_message, headers=headers
|
|
)
|