diff --git a/litellm/__init__.py b/litellm/__init__.py
index 4b63d1d1e0..f70f9d266e 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -1072,7 +1072,7 @@ from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.databricks.chat.transformation import DatabricksConfig
 from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
-from .llms.predibase import PredibaseConfig
+from .llms.predibase.chat.transformation import PredibaseConfig
 from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py
index 3bdb4e5a84..40753839d9 100644
--- a/litellm/litellm_core_utils/get_supported_openai_params.py
+++ b/litellm/litellm_core_utils/get_supported_openai_params.py
@@ -200,4 +200,6 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
             model=model
         )
+    elif custom_llm_provider == "predibase":
+        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
     return None
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase/chat/handler.py
similarity index 71%
rename from litellm/llms/predibase.py
rename to litellm/llms/predibase/chat/handler.py
index 6669812464..467ebf4aad 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase/chat/handler.py
@@ -25,36 +25,9 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
 
-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class PredibaseError(Exception):
-    def __init__(
-        self,
-        status_code,
-        message,
-        request: Optional[httpx.Request] = None,
-        response: Optional[httpx.Response] = None,
-    ):
-        self.status_code = status_code
-        self.message = message
-        if request is not None:
-            self.request = request
-        else:
-            self.request = httpx.Request(
-                method="POST",
-                url="https://docs.predibase.com/user-guide/inference/rest_api",
-            )
-        if response is not None:
-            self.response = response
-        else:
-            self.response = httpx.Response(
-                status_code=status_code, request=self.request
-            )
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import PredibaseError
 
 
 async def make_call(
@@ -86,143 +59,10 @@ async def make_call(
     return completion_stream
 
 
-class PredibaseConfig:
-    """
-    Reference: https://docs.predibase.com/user-guide/inference/rest_api
-
-    """
-
-    adapter_id: Optional[str] = None
-    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
-    best_of: Optional[int] = None
-    decoder_input_details: Optional[bool] = None
-    details: bool = True  # enables returning logprobs + best of
-    max_new_tokens: int = (
-        256  # openai default - requests hang if max_new_tokens not given
-    )
-    repetition_penalty: Optional[float] = None
-    return_full_text: Optional[bool] = (
-        False  # by default don't return the input as part of the output
-    )
-    seed: Optional[int] = None
-    stop: Optional[List[str]] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[int] = None
-    truncate: Optional[int] = None
-    typical_p: Optional[float] = None
-    watermark: Optional[bool] = None
-
-    def __init__(
-        self,
-        best_of: Optional[int] = None,
-        decoder_input_details: Optional[bool] = None,
-        details: Optional[bool] = None,
-        max_new_tokens: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] = None,
-        seed: Optional[int] = None,
-        stop: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[int] = None,
-        truncate: Optional[int] = None,
-        typical_p: Optional[float] = None,
-        watermark: Optional[bool] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "temperature",
-            "max_completion_tokens",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "n",
-            "response_format",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-            if param == "temperature":
-                if value == 0.0 or value == 0:
-                    # hugging face exception raised when temp==0
-                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                    value = 0.01
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "n":
-                optional_params["best_of"] = value
-                optional_params["do_sample"] = (
-                    True  # Need to sample if you want best of for hf inference endpoints
-                )
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-            if param == "max_tokens" or param == "max_completion_tokens":
-                # HF TGI raises the following exception when max_new_tokens==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-                if value == 0:
-                    value = 1
-                optional_params["max_new_tokens"] = value
-            if param == "echo":
-                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-                optional_params["decoder_input_details"] = True
-            if param == "response_format":
-                optional_params["response_format"] = value
-        return optional_params
-
-
 class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def _validate_environment(
-        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
-    ) -> dict:
-        if api_key is None:
-            raise ValueError(
-                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
-            )
-        if tenant_id is None:
-            raise ValueError(
-                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=)`) or in env - `PREDIBASE_TENANT_ID`."
-            )
-        headers = {
-            "content-type": "application/json",
-            "Authorization": "Bearer {}".format(api_key),
-        }
-        if user_headers is not None and isinstance(user_headers, dict):
-            headers = {**headers, **user_headers}
-        return headers
-
     def output_parser(self, generated_text: str):
         """
         Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
@@ -398,7 +238,13 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
+        headers = litellm.PredibaseConfig().validate_environment(
+            api_key=api_key,
+            headers=headers,
+            messages=messages,
+            optional_params=optional_params,
+            model=model,
+        )
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py
new file mode 100644
index 0000000000..c6b9451dd3
--- /dev/null
+++ b/litellm/llms/predibase/chat/transformation.py
@@ -0,0 +1,185 @@
+import types
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import PredibaseError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class PredibaseConfig(BaseConfig):
+    """
+    Reference: https://docs.predibase.com/user-guide/inference/rest_api
+    """
+
+    adapter_id: Optional[str] = None
+    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
+    best_of: Optional[int] = None
+    decoder_input_details: Optional[bool] = None
+    details: bool = True  # enables returning logprobs + best of
+    max_new_tokens: int = (
+        256  # openai default - requests hang if max_new_tokens not given
+    )
+    repetition_penalty: Optional[float] = None
+    return_full_text: Optional[bool] = (
+        False  # by default don't return the input as part of the output
+    )
+    seed: Optional[int] = None
+    stop: Optional[List[str]] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[int] = None
+    truncate: Optional[int] = None
+    typical_p: Optional[float] = None
+    watermark: Optional[bool] = None
+
+    def __init__(
+        self,
+        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        details: Optional[bool] = None,
+        max_new_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[int] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "stream",
+            "temperature",
+            "max_completion_tokens",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens" or param == "max_completion_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        return messages
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PredibaseError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
+            )
+
+        default_headers = {
+            "content-type": "application/json",
+            "Authorization": "Bearer {}".format(api_key),
+        }
+        if headers is not None and isinstance(headers, dict):
+            headers = {**default_headers, **headers}
+        return headers
diff --git a/litellm/llms/predibase/common_utils.py b/litellm/llms/predibase/common_utils.py
new file mode 100644
index 0000000000..f1506ce219
--- /dev/null
+++ b/litellm/llms/predibase/common_utils.py
@@ -0,0 +1,23 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class PredibaseError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )
diff --git a/litellm/main.py b/litellm/main.py
index 62054df3ae..4350689d00 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -104,13 +104,12 @@ from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.codestral.completion.handler import CodestralTextCompletion
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
-from .llms.deprecated_providers import palm, aleph_alpha
+from .llms.deprecated_providers import aleph_alpha, palm
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface.chat.handler import Huggingface
 from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
@@ -121,10 +120,11 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
-from .llms.predibase import PredibaseChatCompletion
+from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
+from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
@@ -2328,6 +2328,11 @@ def completion(  # type: ignore # noqa: PLR0915
             or get_secret("PREDIBASE_TENANT_ID")
         )
 
+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
+
         api_base = (
             api_base
             or optional_params.pop("api_base", None)
diff --git a/litellm/utils.py b/litellm/utils.py
index 7b067c4b6f..a55708f94b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -369,8 +369,14 @@ def function_setup(  # noqa: PLR0915
                     litellm._async_success_callback.append(callback)
                     removed_async_items.append(index)
                 elif callback in litellm._known_custom_logger_compatible_callbacks:
-                    callback_class = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
-                        callback, internal_usage_cache=None, llm_router=None  # type: ignore
+                    from litellm.litellm_core_utils.litellm_logging import (
+                        _init_custom_logger_compatible_class,
+                    )
+
+                    callback_class = _init_custom_logger_compatible_class(
+                        callback,  # type: ignore
+                        internal_usage_cache=None,
+                        llm_router=None,  # type: ignore
                     )
 
                     # don't double add a callback
@@ -2941,7 +2947,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.PredibaseConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
@@ -4102,7 +4115,6 @@ def get_api_base(
                     partner=VertexPartnerProvider.claude,
                 )
             else:
-
                 if stream:
                     _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
                         _optional_params.vertex_location,
@@ -6359,6 +6371,8 @@ class ProviderConfigManager:
             return litellm.VLLMConfig()
         elif litellm.LlmProviders.OLLAMA == provider:
            return litellm.OllamaConfig()
+        elif litellm.LlmProviders.PREDIBASE == provider:
+            return litellm.PredibaseConfig()
         return litellm.OpenAIGPTConfig()
 
 
diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py
index f0e8153abd..905c60342d 100644
--- a/tests/llm_translation/test_max_completion_tokens.py
+++ b/tests/llm_translation/test_max_completion_tokens.py
@@ -204,12 +204,16 @@ def test_all_model_configs():
         drop_params=False,
     ) == {"num_predict": 10}
 
-    from litellm.llms.predibase import PredibaseConfig
+    from litellm.llms.predibase.chat.transformation import PredibaseConfig
 
-    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params(
+        model="llama3"
+    )
     assert PredibaseConfig().map_openai_params(
-        {"max_completion_tokens": 10},
-        {},
+        model="llama3",
+        non_default_params={"max_completion_tokens": 10},
+        optional_params={},
+        drop_params=False,
    ) == {"max_new_tokens": 10}
 
     from litellm.llms.codestral.completion.transformation import (