Krrish Dholakia 2024-12-11 01:00:33 -08:00
parent 78d132c1fb
commit b9b34a7b99
8 changed files with 255 additions and 176 deletions

View file

@@ -1072,7 +1072,7 @@ from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.databricks.chat.transformation import DatabricksConfig
 from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
-from .llms.predibase import PredibaseConfig
+from .llms.predibase.chat.transformation import PredibaseConfig
 from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
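This hunk only changes the module the package-level import block pulls from; the top-level re-export stays the same, so downstream code that references litellm.PredibaseConfig keeps working. A minimal check, assuming nothing beyond what the hunk above shows:

import litellm
from litellm.llms.predibase.chat.transformation import PredibaseConfig

# The package now imports PredibaseConfig from the relocated module, so the
# top-level attribute and the new class are the same object.
assert litellm.PredibaseConfig is PredibaseConfig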

View file

@@ -200,4 +200,6 @@ def get_supported_openai_params( # noqa: PLR0915
         return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
             model=model
         )
+    elif custom_llm_provider == "predibase":
+        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
     return None
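With this branch in place, supported-parameter lookups for Predibase are answered by the provider config instead of falling through to None. A quick sketch using only the config added in this commit (the model name is illustrative; the method returns a static list and does not inspect it):

import litellm

params = litellm.PredibaseConfig().get_supported_openai_params(
    model="llama-3-1-8b-instruct"  # illustrative
)
# Per transformation.py below, the list includes both "max_tokens" and "max_completion_tokens".
assert "max_completion_tokens" in params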

View file

@@ -25,36 +25,9 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
 
-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import PredibaseError
 
 
-class PredibaseError(Exception):
-    def __init__(
-        self,
-        status_code,
-        message,
-        request: Optional[httpx.Request] = None,
-        response: Optional[httpx.Response] = None,
-    ):
-        self.status_code = status_code
-        self.message = message
-        if request is not None:
-            self.request = request
-        else:
-            self.request = httpx.Request(
-                method="POST",
-                url="https://docs.predibase.com/user-guide/inference/rest_api",
-            )
-        if response is not None:
-            self.response = response
-        else:
-            self.response = httpx.Response(
-                status_code=status_code, request=self.request
-            )
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
 async def make_call(
@@ -86,143 +59,10 @@ async def make_call(
     return completion_stream
 
 
-class PredibaseConfig:
-    """
-    Reference: https://docs.predibase.com/user-guide/inference/rest_api
-    """
-
-    adapter_id: Optional[str] = None
-    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
-    best_of: Optional[int] = None
-    decoder_input_details: Optional[bool] = None
-    details: bool = True  # enables returning logprobs + best of
-    max_new_tokens: int = (
-        256  # openai default - requests hang if max_new_tokens not given
-    )
-    repetition_penalty: Optional[float] = None
-    return_full_text: Optional[bool] = (
-        False  # by default don't return the input as part of the output
-    )
-    seed: Optional[int] = None
-    stop: Optional[List[str]] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[int] = None
-    truncate: Optional[int] = None
-    typical_p: Optional[float] = None
-    watermark: Optional[bool] = None
-
-    def __init__(
-        self,
-        best_of: Optional[int] = None,
-        decoder_input_details: Optional[bool] = None,
-        details: Optional[bool] = None,
-        max_new_tokens: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] = None,
-        seed: Optional[int] = None,
-        stop: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[int] = None,
-        truncate: Optional[int] = None,
-        typical_p: Optional[float] = None,
-        watermark: Optional[bool] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "temperature",
-            "max_completion_tokens",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "n",
-            "response_format",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-            if param == "temperature":
-                if value == 0.0 or value == 0:
-                    # hugging face exception raised when temp==0
-                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                    value = 0.01
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "n":
-                optional_params["best_of"] = value
-                optional_params["do_sample"] = (
-                    True  # Need to sample if you want best of for hf inference endpoints
-                )
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-            if param == "max_tokens" or param == "max_completion_tokens":
-                # HF TGI raises the following exception when max_new_tokens==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-                if value == 0:
-                    value = 1
-                optional_params["max_new_tokens"] = value
-            if param == "echo":
-                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-                optional_params["decoder_input_details"] = True
-            if param == "response_format":
-                optional_params["response_format"] = value
-        return optional_params
-
-
 class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def _validate_environment(
-        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
-    ) -> dict:
-        if api_key is None:
-            raise ValueError(
-                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
-            )
-        if tenant_id is None:
-            raise ValueError(
-                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
-            )
-        headers = {
-            "content-type": "application/json",
-            "Authorization": "Bearer {}".format(api_key),
-        }
-        if user_headers is not None and isinstance(user_headers, dict):
-            headers = {**headers, **user_headers}
-        return headers
-
     def output_parser(self, generated_text: str):
         """
         Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
@@ -398,7 +238,13 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
+        headers = litellm.PredibaseConfig().validate_environment(
+            api_key=api_key,
+            headers=headers,
+            messages=messages,
+            optional_params=optional_params,
+            model=model,
+        )
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
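The handler no longer assembles request headers itself; it delegates to the shared config object. A minimal sketch of what that delegated call produces, based only on the validate_environment method added in transformation.py below (the key and model name are placeholders):

from litellm.llms.predibase.chat.transformation import PredibaseConfig

headers = PredibaseConfig().validate_environment(
    api_key="pb_xxx",  # placeholder; a missing key raises ValueError
    headers={"x-custom-header": "1"},  # caller headers are merged over the defaults
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    model="llama-3-1-8b-instruct",  # illustrative; not inspected here
)
# headers == {
#     "content-type": "application/json",
#     "Authorization": "Bearer pb_xxx",
#     "x-custom-header": "1",
# }

Note that the tenant-id check which used to live in _validate_environment now happens in completion() (see the main.py hunk further down).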

View file

@@ -0,0 +1,185 @@
+import types
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import PredibaseError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class PredibaseConfig(BaseConfig):
+    """
+    Reference: https://docs.predibase.com/user-guide/inference/rest_api
+    """
+
+    adapter_id: Optional[str] = None
+    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
+    best_of: Optional[int] = None
+    decoder_input_details: Optional[bool] = None
+    details: bool = True  # enables returning logprobs + best of
+    max_new_tokens: int = (
+        256  # openai default - requests hang if max_new_tokens not given
+    )
+    repetition_penalty: Optional[float] = None
+    return_full_text: Optional[bool] = (
+        False  # by default don't return the input as part of the output
+    )
+    seed: Optional[int] = None
+    stop: Optional[List[str]] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[int] = None
+    truncate: Optional[int] = None
+    typical_p: Optional[float] = None
+    watermark: Optional[bool] = None
+
+    def __init__(
+        self,
+        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        details: Optional[bool] = None,
+        max_new_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[int] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "stream",
+            "temperature",
+            "max_completion_tokens",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens" or param == "max_completion_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        return messages
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PredibaseError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
+            )
+        default_headers = {
+            "content-type": "application/json",
+            "Authorization": "Bearer {}".format(api_key),
+        }
+        if headers is not None and isinstance(headers, dict):
+            headers = {**default_headers, **headers}
+        return headers
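The parameter-mapping behaviour carries over unchanged from the deleted module; only the signature gains model and drop_params to satisfy BaseConfig. A short illustration of the mapping, using only the code in this file (the model name is arbitrary and ignored by the method):

from litellm.llms.predibase.chat.transformation import PredibaseConfig

mapped = PredibaseConfig().map_openai_params(
    non_default_params={"max_tokens": 100, "temperature": 0, "n": 2},
    optional_params={},
    model="llama-3-1-8b-instruct",  # illustrative
    drop_params=False,
)
# temperature=0 is bumped to 0.01, n maps to best_of + do_sample, and
# max_tokens becomes the TGI-style max_new_tokens:
# mapped == {"max_new_tokens": 100, "temperature": 0.01, "best_of": 2, "do_sample": True}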

View file

@@ -0,0 +1,23 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class PredibaseError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )
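Because PredibaseError now subclasses BaseLLMException, the hand-built httpx request/response defaults from the old class are gone and construction is delegated to the base exception. This also lets the generic error path create it via the config. A sketch of that wiring (status code and message are made up):

from litellm.llms.predibase.chat.transformation import PredibaseConfig
from litellm.llms.predibase.common_utils import PredibaseError

err = PredibaseConfig().get_error_class(
    error_message="adapter not found",  # illustrative
    status_code=404,
    headers={},
)
assert isinstance(err, PredibaseError)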

View file

@@ -104,13 +104,12 @@ from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.codestral.completion.handler import CodestralTextCompletion
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
-from .llms.deprecated_providers import palm, aleph_alpha
+from .llms.deprecated_providers import aleph_alpha, palm
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface.chat.handler import Huggingface
 from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
@@ -121,10 +120,11 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
-from .llms.predibase import PredibaseChatCompletion
+from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
+from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
@@ -2328,6 +2328,11 @@ def completion( # type: ignore # noqa: PLR0915
             or get_secret("PREDIBASE_TENANT_ID")
         )
 
+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
+
         api_base = (
             api_base
             or optional_params.pop("api_base", None)
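Hoisting the tenant check into completion() surfaces the configuration error before any request is built. A hedged usage sketch (model name, key, and tenant id are placeholders; PREDIBASE_API_KEY / PREDIBASE_TENANT_ID can be set in the environment instead):

import litellm

response = litellm.completion(
    model="predibase/llama-3-1-8b-instruct",  # placeholder deployment name
    messages=[{"role": "user", "content": "Hello"}],
    api_key="pb_xxx",       # placeholder key
    tenant_id="my-tenant",  # popped from optional_params, as shown above
    max_tokens=100,
)
# Leaving out tenant_id (and PREDIBASE_TENANT_ID) now raises the ValueError added in this hunk.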

View file

@@ -369,8 +369,14 @@ def function_setup( # noqa: PLR0915
                     litellm._async_success_callback.append(callback)
                     removed_async_items.append(index)
                 elif callback in litellm._known_custom_logger_compatible_callbacks:
-                    callback_class = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
-                        callback, internal_usage_cache=None, llm_router=None  # type: ignore
+                    from litellm.litellm_core_utils.litellm_logging import (
+                        _init_custom_logger_compatible_class,
+                    )
+
+                    callback_class = _init_custom_logger_compatible_class(
+                        callback,  # type: ignore
+                        internal_usage_cache=None,
+                        llm_router=None,  # type: ignore
                     )
 
                     # don't double add a callback
@@ -2941,7 +2947,14 @@ def get_optional_params( # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.PredibaseConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
@@ -4102,7 +4115,6 @@ def get_api_base(
                 partner=VertexPartnerProvider.claude,
             )
         else:
-
             if stream:
                 _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
                     _optional_params.vertex_location,
@@ -6359,6 +6371,8 @@ class ProviderConfigManager:
             return litellm.VLLMConfig()
         elif litellm.LlmProviders.OLLAMA == provider:
             return litellm.OllamaConfig()
+        elif litellm.LlmProviders.PREDIBASE == provider:
+            return litellm.PredibaseConfig()
         return litellm.OpenAIGPTConfig()
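Registering Predibase here lets generic call sites resolve the provider's config from the LlmProviders enum instead of importing PredibaseConfig directly. A sketch of that lookup, assuming the hunk above sits in the manager's chat-config getter (the method name below is taken from that assumption):

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(  # assumed method name
    model="llama-3-1-8b-instruct",  # illustrative
    provider=litellm.LlmProviders.PREDIBASE,
)
assert isinstance(config, litellm.PredibaseConfig)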

View file

@@ -204,12 +204,16 @@ def test_all_model_configs():
         drop_params=False,
     ) == {"num_predict": 10}
 
-    from litellm.llms.predibase import PredibaseConfig
+    from litellm.llms.predibase.chat.transformation import PredibaseConfig
 
-    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params(
+        model="llama3"
+    )
     assert PredibaseConfig().map_openai_params(
-        {"max_completion_tokens": 10},
-        {},
+        model="llama3",
+        non_default_params={"max_completion_tokens": 10},
+        optional_params={},
+        drop_params=False,
     ) == {"max_new_tokens": 10}
 
     from litellm.llms.codestral.completion.transformation import (