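"""
Predibase chat transformation config for litellm.

Defines PredibaseConfig, which maps OpenAI-style request parameters onto
Predibase generation parameters and validates the request environment.
Reference: https://docs.predibase.com/user-guide/inference/rest_api
"""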
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union

from httpx import Headers, Response

from litellm.constants import DEFAULT_MAX_TOKENS
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

from ..common_utils import PredibaseError
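# The logging object is imported only for static type checkers; at runtime
# LiteLLMLoggingObj falls back to Any (see the TYPE_CHECKING block below).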
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any

class PredibaseConfig(BaseConfig):
    """
    Reference: https://docs.predibase.com/user-guide/inference/rest_api
    """

    adapter_id: Optional[str] = None
    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: bool = True  # enables returning logprobs + best of
    max_new_tokens: int = (
        DEFAULT_MAX_TOKENS  # openai default - requests hang if max_new_tokens not given
    )
    repetition_penalty: Optional[float] = None
    return_full_text: Optional[bool] = (
        False  # by default don't return the input as part of the output
    )
    seed: Optional[int] = None
    stop: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    truncate: Optional[int] = None
    typical_p: Optional[float] = None
    watermark: Optional[bool] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        details: Optional[bool] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
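        # Any explicitly passed value is stored on the class itself, so it
        # acts as a class-level default rather than an instance attribute.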
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str):
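        """Return the OpenAI request parameters supported for Predibase models."""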
        return [
            "stream",
            "temperature",
            "max_completion_tokens",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
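        """
        Translate the supported OpenAI parameters into Predibase generation
        parameters, e.g. ``max_tokens`` -> ``max_new_tokens`` and
        ``n`` -> ``best_of``.
        """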
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                # Need to sample if you want best of for hf inference endpoints
                optional_params["do_sample"] = True
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens" or param == "max_completion_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            if param == "echo":
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
            if param == "response_format":
                optional_params["response_format"] = value
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return PredibaseError(
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
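        """
        Require a Predibase API key and attach it as a Bearer token; any
        caller-supplied headers override the defaults.
        """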
        if api_key is None:
            raise ValueError(
                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
            )

        default_headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if headers is not None and isinstance(headers, dict):
            headers = {**default_headers, **headers}
        return headers
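

# A minimal usage sketch (illustrative only; the model name and values below
# are made up) showing how map_openai_params translates OpenAI-style
# parameters into Predibase generation parameters:
#
#     config = PredibaseConfig()
#     params = config.map_openai_params(
#         non_default_params={"temperature": 0, "max_tokens": 256, "n": 2},
#         optional_params={},
#         model="my-predibase-model",
#         drop_params=False,
#     )
#     # -> {"temperature": 0.01, "max_new_tokens": 256,
#     #     "best_of": 2, "do_sample": True}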