mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
* build(pyproject.toml): add new dev dependencies - for type checking * build: reformat files to fit black * ci: reformat to fit black * ci(test-litellm.yml): make tests run clear * build(pyproject.toml): add ruff * fix: fix ruff checks * build(mypy/): fix mypy linting errors * fix(hashicorp_secret_manager.py): fix passing cert for tls auth * build(mypy/): resolve all mypy errors * test: update test * fix: fix black formatting * build(pre-commit-config.yaml): use poetry run black * fix(proxy_server.py): fix linting error * fix: fix ruff safe representation error
65 lines · 2 KiB · Python
from typing import List, Optional, Union
|
|
|
|
from httpx._models import Headers
|
|
|
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
|
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
|
|
|
|
|
class MaritalkError(BaseLLMException):
    """Error type produced by ``MaritalkConfig.get_error_class``.

    Thin subclass of ``BaseLLMException``: it adds no state of its own and
    forwards the HTTP status code, message and optional headers unchanged
    to the base-class constructor.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, Headers]] = None,
    ):
        # Pass everything by keyword so the call does not depend on the
        # base class's positional parameter order.
        super().__init__(
            message=message,
            status_code=status_code,
            headers=headers,
        )
|
|
|
|
|
|
class MaritalkConfig(OpenAIGPTConfig):
    """Maritalk provider configuration layered on ``OpenAIGPTConfig``.

    Only parameter handling, the supported-parameter list and the error
    type are customised here; everything else is inherited unchanged from
    ``OpenAIGPTConfig``.
    """

    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[List[str]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        tools: Optional[List[dict]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        """Record every explicitly supplied parameter as a class attribute.

        Arguments left as ``None`` are skipped. Values are written onto the
        class via ``setattr(self.__class__, ...)``, not onto the instance, so
        they are shared by every instance of the class. Presumably the
        inherited ``get_config`` reads them back from there; confirm this in
        the base class.
        """
        requested = {
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "n": n,
            "stop": stop,
            "stream": stream,
            "stream_options": stream_options,
            "tools": tools,
            "tool_choice": tool_choice,
        }
        for attr_name, attr_value in requested.items():
            if attr_value is None:
                continue
            setattr(self.__class__, attr_name, attr_value)

    @classmethod
    def get_config(cls):
        """Delegate to the parent implementation without modification."""
        parent_config = super().get_config()
        return parent_config

    def get_supported_openai_params(self, model: str) -> List:
        """Return the OpenAI-style parameter names accepted for Maritalk.

        ``model`` is not consulted, so every model gets the same list. A new
        list is built on each call.
        """
        penalties_and_sampling = (
            "frequency_penalty",
            "presence_penalty",
            "top_p",
            "top_k",
            "temperature",
        )
        length_and_count = ("max_tokens", "n", "stop")
        streaming = ("stream", "stream_options")
        tooling = ("tools", "tool_choice")
        return [*penalties_and_sampling, *length_and_count, *streaming, *tooling]

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Build a ``MaritalkError`` from the given message, status code and headers."""
        error = MaritalkError(
            message=error_message,
            status_code=status_code,
            headers=headers,
        )
        return error
|