Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
build: Squashed commit of https://github.com/BerriAI/litellm/pull/7165

Closes https://github.com/BerriAI/litellm/pull/7165

Parent: b79db3616c
Commit: 6493eaf2ee

8 changed files with 255 additions and 176 deletions
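In short, the flat `litellm.llms.predibase` import surface is replaced by a `predibase/` package with `chat/` and `common_utils` modules. A brief before/after of the import paths, taken from the removed and added lines in this diff (a sketch, not an exhaustive list):

# Old import paths (removed in this diff):
# from litellm.llms.predibase import PredibaseConfig
# from litellm.llms.predibase import PredibaseChatCompletion

# New import paths (added in this diff):
from litellm.llms.predibase.chat.transformation import PredibaseConfig
from litellm.llms.predibase.chat.handler import PredibaseChatCompletion
from litellm.llms.predibase.common_utils import PredibaseError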
@@ -1072,7 +1072,7 @@ from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.databricks.chat.transformation import DatabricksConfig
 from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
-from .llms.predibase import PredibaseConfig
+from .llms.predibase.chat.transformation import PredibaseConfig
 from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
@@ -200,4 +200,6 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
             model=model
         )
+    elif custom_llm_provider == "predibase":
+        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
     return None
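With the new branch in place, callers can query the supported OpenAI-style parameters for a Predibase model through the existing top-level helper. A minimal sketch, using a placeholder model name:

from litellm import get_supported_openai_params

supported = get_supported_openai_params(
    model="llama-3-8b-instruct",          # placeholder model name
    custom_llm_provider="predibase",
)
# Per PredibaseConfig.get_supported_openai_params, this should include
# "temperature", "max_tokens", "max_completion_tokens", "top_p", "stop", ...
print(supported)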
@@ -25,36 +25,9 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class PredibaseError(Exception):
-    def __init__(
-        self,
-        status_code,
-        message,
-        request: Optional[httpx.Request] = None,
-        response: Optional[httpx.Response] = None,
-    ):
-        self.status_code = status_code
-        self.message = message
-        if request is not None:
-            self.request = request
-        else:
-            self.request = httpx.Request(
-                method="POST",
-                url="https://docs.predibase.com/user-guide/inference/rest_api",
-            )
-        if response is not None:
-            self.response = response
-        else:
-            self.response = httpx.Response(
-                status_code=status_code, request=self.request
-            )
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import PredibaseError


 async def make_call(
@@ -86,143 +59,10 @@ async def make_call(
     return completion_stream


-class PredibaseConfig:
-    """
-    Reference: https://docs.predibase.com/user-guide/inference/rest_api
-
-    """
-
-    adapter_id: Optional[str] = None
-    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
-    best_of: Optional[int] = None
-    decoder_input_details: Optional[bool] = None
-    details: bool = True  # enables returning logprobs + best of
-    max_new_tokens: int = (
-        256  # openai default - requests hang if max_new_tokens not given
-    )
-    repetition_penalty: Optional[float] = None
-    return_full_text: Optional[bool] = (
-        False  # by default don't return the input as part of the output
-    )
-    seed: Optional[int] = None
-    stop: Optional[List[str]] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[int] = None
-    truncate: Optional[int] = None
-    typical_p: Optional[float] = None
-    watermark: Optional[bool] = None
-
-    def __init__(
-        self,
-        best_of: Optional[int] = None,
-        decoder_input_details: Optional[bool] = None,
-        details: Optional[bool] = None,
-        max_new_tokens: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] = None,
-        seed: Optional[int] = None,
-        stop: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[int] = None,
-        truncate: Optional[int] = None,
-        typical_p: Optional[float] = None,
-        watermark: Optional[bool] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "temperature",
-            "max_completion_tokens",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "n",
-            "response_format",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-            if param == "temperature":
-                if value == 0.0 or value == 0:
-                    # hugging face exception raised when temp==0
-                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                    value = 0.01
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "n":
-                optional_params["best_of"] = value
-                optional_params["do_sample"] = (
-                    True  # Need to sample if you want best of for hf inference endpoints
-                )
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-            if param == "max_tokens" or param == "max_completion_tokens":
-                # HF TGI raises the following exception when max_new_tokens==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-                if value == 0:
-                    value = 1
-                optional_params["max_new_tokens"] = value
-            if param == "echo":
-                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-                optional_params["decoder_input_details"] = True
-            if param == "response_format":
-                optional_params["response_format"] = value
-        return optional_params
-
-
 class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

-    def _validate_environment(
-        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
-    ) -> dict:
-        if api_key is None:
-            raise ValueError(
-                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
-            )
-        if tenant_id is None:
-            raise ValueError(
-                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
-            )
-        headers = {
-            "content-type": "application/json",
-            "Authorization": "Bearer {}".format(api_key),
-        }
-        if user_headers is not None and isinstance(user_headers, dict):
-            headers = {**headers, **user_headers}
-        return headers
-
     def output_parser(self, generated_text: str):
         """
         Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
@@ -398,7 +238,13 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
+        headers = litellm.PredibaseConfig().validate_environment(
+            api_key=api_key,
+            headers=headers,
+            messages=messages,
+            optional_params=optional_params,
+            model=model,
+        )
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
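The handler now delegates header construction to `PredibaseConfig.validate_environment` instead of its private `_validate_environment`; the tenant-ID check no longer lives here (it moves into `completion()`, see the hunk further down). A minimal sketch of the new call, with placeholder values:

import litellm

headers = litellm.PredibaseConfig().validate_environment(
    api_key="pb_example_key",                      # placeholder key
    headers={"x-request-id": "abc123"},            # caller-supplied headers are merged in
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    model="llama-3-8b-instruct",                   # placeholder model name
)
# -> {"content-type": "application/json",
#     "Authorization": "Bearer pb_example_key",
#     "x-request-id": "abc123"}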
litellm/llms/predibase/chat/transformation.py  (new file, 185 lines)

@@ -0,0 +1,185 @@
import types
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union

from httpx import Headers, Response

from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

from ..common_utils import PredibaseError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class PredibaseConfig(BaseConfig):
    """
    Reference: https://docs.predibase.com/user-guide/inference/rest_api
    """

    adapter_id: Optional[str] = None
    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: bool = True  # enables returning logprobs + best of
    max_new_tokens: int = (
        256  # openai default - requests hang if max_new_tokens not given
    )
    repetition_penalty: Optional[float] = None
    return_full_text: Optional[bool] = (
        False  # by default don't return the input as part of the output
    )
    seed: Optional[int] = None
    stop: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    truncate: Optional[int] = None
    typical_p: Optional[float] = None
    watermark: Optional[bool] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        details: Optional[bool] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str):
        return [
            "stream",
            "temperature",
            "max_completion_tokens",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                optional_params["do_sample"] = (
                    True  # Need to sample if you want best of for hf inference endpoints
                )
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens" or param == "max_completion_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            if param == "echo":
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
            if param == "response_format":
                optional_params["response_format"] = value
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        return messages

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return PredibaseError(
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
            )

        default_headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if headers is not None and isinstance(headers, dict):
            headers = {**default_headers, **headers}
        return headers
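A short sketch of how the new `map_openai_params` signature translates OpenAI-style arguments into Predibase/TGI arguments; the model name and values are placeholders:

import litellm

mapped = litellm.PredibaseConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_completion_tokens": 64, "n": 2},
    optional_params={},
    model="llama-3-8b-instruct",  # placeholder model name
    drop_params=False,
)
# temperature=0 is bumped to 0.01 (TGI rejects 0), max_completion_tokens maps to
# max_new_tokens, and n maps to best_of with do_sample=True:
# {"temperature": 0.01, "max_new_tokens": 64, "best_of": 2, "do_sample": True}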
litellm/llms/predibase/common_utils.py  (new file, 23 lines)

@@ -0,0 +1,23 @@
from typing import Optional, Union

import httpx

from litellm.llms.base_llm.transformation import BaseLLMException


class PredibaseError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[httpx.Headers, dict]] = None,
    ):
        super().__init__(
            status_code=status_code,
            message=message,
            request=request,
            response=response,
            headers=headers,
        )
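`PredibaseError` now inherits from `BaseLLMException` rather than building its own httpx request/response fallbacks. A hedged sketch of constructing it, assuming `BaseLLMException` exposes `status_code` and `message` the same way the old handler-local class did:

from litellm.llms.predibase.common_utils import PredibaseError

try:
    # Placeholder error details, for illustration only.
    raise PredibaseError(
        status_code=401,
        message="Invalid Predibase API key",
        headers={"content-type": "application/json"},
    )
except PredibaseError as err:
    print(err.status_code, err.message)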
@@ -104,13 +104,12 @@ from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
 from .llms.codestral.completion.handler import CodestralTextCompletion
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
-from .llms.deprecated_providers import palm, aleph_alpha
+from .llms.deprecated_providers import aleph_alpha, palm
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface.chat.handler import Huggingface
 from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion

@@ -121,10 +120,11 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
-from .llms.predibase import PredibaseChatCompletion
+from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
@@ -2328,6 +2328,11 @@ def completion(  # type: ignore # noqa: PLR0915
             or get_secret("PREDIBASE_TENANT_ID")
         )

+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
+
         api_base = (
             api_base
             or optional_params.pop("api_base", None)
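The tenant-ID requirement is now enforced directly in `completion()`. A sketch of a call that satisfies it, with placeholder credentials and model name; the `predibase/` model prefix and `PREDIBASE_*` environment variables follow LiteLLM's usual provider conventions and are assumed here rather than shown in this diff:

import os
import litellm

os.environ["PREDIBASE_API_KEY"] = "pb_example_key"   # placeholder
os.environ["PREDIBASE_TENANT_ID"] = "my-tenant-id"   # or pass tenant_id=... below

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",            # placeholder model
    messages=[{"role": "user", "content": "Hello"}],
    # tenant_id="my-tenant-id",                       # equivalent to the env var
)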
@@ -369,8 +369,14 @@ def function_setup(  # noqa: PLR0915
                 litellm._async_success_callback.append(callback)
                 removed_async_items.append(index)
             elif callback in litellm._known_custom_logger_compatible_callbacks:
-                callback_class = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
-                    callback, internal_usage_cache=None, llm_router=None  # type: ignore
+                from litellm.litellm_core_utils.litellm_logging import (
+                    _init_custom_logger_compatible_class,
+                )
+
+                callback_class = _init_custom_logger_compatible_class(
+                    callback,  # type: ignore
+                    internal_usage_cache=None,
+                    llm_router=None,  # type: ignore
                 )

                 # don't double add a callback
@@ -2941,7 +2947,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.PredibaseConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
@@ -4102,7 +4115,6 @@ def get_api_base(
                 partner=VertexPartnerProvider.claude,
             )
         else:
-
             if stream:
                 _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
                     _optional_params.vertex_location,
@@ -6359,6 +6371,8 @@ class ProviderConfigManager:
             return litellm.VLLMConfig()
         elif litellm.LlmProviders.OLLAMA == provider:
             return litellm.OllamaConfig()
+        elif litellm.LlmProviders.PREDIBASE == provider:
+            return litellm.PredibaseConfig()
         return litellm.OpenAIGPTConfig()

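`ProviderConfigManager` now resolves `PredibaseConfig` for the PREDIBASE provider. The hunk does not show the enclosing method, so the lookup call and import location below are assumptions for illustration only:

import litellm
from litellm.utils import ProviderConfigManager  # module inferred from surrounding hunks

# Assumed lookup API; only the added elif branch is visible in this diff.
config = ProviderConfigManager.get_provider_chat_config(
    model="llama-3-8b-instruct",                     # placeholder model name
    provider=litellm.LlmProviders.PREDIBASE,
)
assert isinstance(config, litellm.PredibaseConfig)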
@@ -204,12 +204,16 @@ def test_all_model_configs():
         drop_params=False,
     ) == {"num_predict": 10}

-    from litellm.llms.predibase import PredibaseConfig
+    from litellm.llms.predibase.chat.transformation import PredibaseConfig

-    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params(
+        model="llama3"
+    )
     assert PredibaseConfig().map_openai_params(
-        {"max_completion_tokens": 10},
-        {},
+        model="llama3",
+        non_default_params={"max_completion_tokens": 10},
+        optional_params={},
+        drop_params=False,
     ) == {"max_new_tokens": 10}

     from litellm.llms.codestral.completion.transformation import (