"""
|
|
Common base config for all LLM providers
|
|
"""
|
|
|
|
import types
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
|
|
|
import httpx
|
|
|
|
from litellm.types.llms.openai import AllMessageValues
|
|
from litellm.types.utils import ModelResponse
|
|
|
|
if TYPE_CHECKING:
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
|
|
|
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
|
else:
|
|
LiteLLMLoggingObj = Any
|
|
|
|
|
|
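
# For type checkers, `LiteLLMLoggingObj` is the real `Logging` class; at
# runtime it falls back to `Any`, so this module never has to import
# `litellm_logging` when it is executed.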


class BaseLLMException(Exception):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message: str = message
        self.headers = headers
        if request:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST", url="https://docs.litellm.ai/docs"
            )
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
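

# Illustrative sketch (not part of litellm's API): a provider integration might
# convert a failed httpx response into a BaseLLMException like this. The helper
# name `_example_error_from_response` is hypothetical.
def _example_error_from_response(resp: httpx.Response) -> BaseLLMException:
    # Carry status, body text, and headers through so callers can map the
    # failure onto their own exception hierarchy.
    return BaseLLMException(
        status_code=resp.status_code,
        message=resp.text,
        headers=resp.headers,
        response=resp,
    )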


class BaseConfig(ABC):
    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
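
    # Usage note (illustrative): subclasses typically declare provider settings
    # as class-level attributes (e.g. `max_tokens: Optional[int] = None`);
    # `get_config()` snapshots those attributes, skipping dunders, ABC
    # internals, functions, and unset (None) values.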

    def should_fake_stream(
        self,
        model: str,
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Return True if streaming should be faked for this model/provider.
        """
        return False
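
    # When this returns True, the caller can request a complete (non-streaming)
    # response and replay it to the client as synthetic stream chunks, a common
    # fallback for providers without native streaming support.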

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """
        Return the OpenAI request params this provider supports for `model`.
        """
        pass

    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map supported OpenAI params onto provider-specific `optional_params`.

        With `drop_params=True`, unsupported params should be dropped rather
        than raising an error.
        """
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Validate credentials and return the headers to send with the request.
        """
        pass

    def get_complete_url(self, api_base: str, model: str) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base
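
    # e.g. a provider that embeds the model in the URL path might override this
    # with something like (hypothetical):
    #     return f"{api_base}/models/{model}:generate"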

    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform an OpenAI-style request into the provider's request body.
        """
        pass

    @abstractmethod
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the provider's raw HTTP response into a `ModelResponse`.
        """
        pass

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Return the provider-specific exception for a failed request.
        """
        pass

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """
        OPTIONAL

        Return an iterator over the provider's streamed response chunks.
        """
        pass
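

# Illustrative sketch (not part of litellm): a minimal BaseConfig subclass for
# a hypothetical OpenAI-compatible provider, showing how the abstract hooks fit
# together. Every name below (`ExampleChatConfig`, the supported-param list,
# the header scheme) is an assumption for demonstration, not litellm's real
# config for any provider.
class ExampleChatConfig(BaseConfig):
    def get_supported_openai_params(self, model: str) -> list:
        return ["temperature", "max_tokens", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported = self.get_supported_openai_params(model)
        for k, v in non_default_params.items():
            if k in supported:
                optional_params[k] = v
            elif not drop_params:
                raise ValueError(f"{k} is not supported by this provider")
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            raise ValueError("api_key is required for this provider")
        headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # OpenAI-compatible body: model + messages + any mapped params.
        return {"model": model, "messages": messages, **optional_params}

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # A real implementation parses choices/usage out of raw_response.json();
        # this sketch only records the model name.
        model_response.model = model
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return BaseLLMException(
            status_code=status_code, message=error_message, headers=headers
        )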