build: Squashed commit of https://github.com/BerriAI/litellm/pull/7165

Closes https://github.com/BerriAI/litellm/pull/7165

parent b79db3616c
commit 6493eaf2ee

8 changed files with 255 additions and 176 deletions
@@ -1072,7 +1072,7 @@ from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.databricks.chat.transformation import DatabricksConfig
 from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
-from .llms.predibase import PredibaseConfig
+from .llms.predibase.chat.transformation import PredibaseConfig
 from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
@@ -200,4 +200,6 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
             model=model
         )
+    elif custom_llm_provider == "predibase":
+        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
     return None
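Note for reviewers: with this branch, the "predibase" provider is resolved through the new PredibaseConfig. A minimal usage sketch (the model name is a placeholder, and it assumes get_supported_openai_params is exported at the package top level, as in current litellm):

    import litellm

    # "predibase" now routes through PredibaseConfig, so the Predibase-specific
    # parameter list (including "max_completion_tokens") should be returned.
    params = litellm.get_supported_openai_params(
        model="llama-3-1-8b-instruct",  # placeholder model name
        custom_llm_provider="predibase",
    )
    assert params is not None and "max_completion_tokens" in params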
@@ -25,36 +25,9 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
 
-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import PredibaseError
-
-
-class PredibaseError(Exception):
-    def __init__(
-        self,
-        status_code,
-        message,
-        request: Optional[httpx.Request] = None,
-        response: Optional[httpx.Response] = None,
-    ):
-        self.status_code = status_code
-        self.message = message
-        if request is not None:
-            self.request = request
-        else:
-            self.request = httpx.Request(
-                method="POST",
-                url="https://docs.predibase.com/user-guide/inference/rest_api",
-            )
-        if response is not None:
-            self.response = response
-        else:
-            self.response = httpx.Response(
-                status_code=status_code, request=self.request
-            )
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
 
 
 async def make_call(
@@ -86,143 +59,10 @@ async def make_call(
     return completion_stream
 
 
-class PredibaseConfig:
-    """
-    Reference: https://docs.predibase.com/user-guide/inference/rest_api
-
-    """
-
-    adapter_id: Optional[str] = None
-    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
-    best_of: Optional[int] = None
-    decoder_input_details: Optional[bool] = None
-    details: bool = True  # enables returning logprobs + best of
-    max_new_tokens: int = (
-        256  # openai default - requests hang if max_new_tokens not given
-    )
-    repetition_penalty: Optional[float] = None
-    return_full_text: Optional[bool] = (
-        False  # by default don't return the input as part of the output
-    )
-    seed: Optional[int] = None
-    stop: Optional[List[str]] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[int] = None
-    truncate: Optional[int] = None
-    typical_p: Optional[float] = None
-    watermark: Optional[bool] = None
-
-    def __init__(
-        self,
-        best_of: Optional[int] = None,
-        decoder_input_details: Optional[bool] = None,
-        details: Optional[bool] = None,
-        max_new_tokens: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] = None,
-        seed: Optional[int] = None,
-        stop: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[int] = None,
-        truncate: Optional[int] = None,
-        typical_p: Optional[float] = None,
-        watermark: Optional[bool] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "temperature",
-            "max_completion_tokens",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "n",
-            "response_format",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-            if param == "temperature":
-                if value == 0.0 or value == 0:
-                    # hugging face exception raised when temp==0
-                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                    value = 0.01
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "n":
-                optional_params["best_of"] = value
-                optional_params["do_sample"] = (
-                    True  # Need to sample if you want best of for hf inference endpoints
-                )
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-            if param == "max_tokens" or param == "max_completion_tokens":
-                # HF TGI raises the following exception when max_new_tokens==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-                if value == 0:
-                    value = 1
-                optional_params["max_new_tokens"] = value
-            if param == "echo":
-                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-                optional_params["decoder_input_details"] = True
-            if param == "response_format":
-                optional_params["response_format"] = value
-        return optional_params
-
-
 class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def _validate_environment(
-        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
-    ) -> dict:
-        if api_key is None:
-            raise ValueError(
-                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
-            )
-        if tenant_id is None:
-            raise ValueError(
-                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
-            )
-        headers = {
-            "content-type": "application/json",
-            "Authorization": "Bearer {}".format(api_key),
-        }
-        if user_headers is not None and isinstance(user_headers, dict):
-            headers = {**headers, **user_headers}
-        return headers
-
     def output_parser(self, generated_text: str):
         """
         Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
@@ -398,7 +238,13 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
+        headers = litellm.PredibaseConfig().validate_environment(
+            api_key=api_key,
+            headers=headers,
+            messages=messages,
+            optional_params=optional_params,
+            model=model,
+        )
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
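The handler no longer builds auth headers itself; it delegates to the shared config object. A rough sketch of the delegated call, assuming the validate_environment signature introduced in the new transformation.py below (all values are placeholders):

    from litellm.llms.predibase.chat.transformation import PredibaseConfig

    headers = PredibaseConfig().validate_environment(
        headers={"x-request-id": "abc123"},  # caller-supplied headers are merged in last
        model="llama-3-1-8b-instruct",       # placeholder model name
        messages=[{"role": "user", "content": "hi"}],
        optional_params={},
        api_key="pb_xxx",                    # placeholder key; a missing key raises ValueError
    )
    # Expected result:
    # {"content-type": "application/json", "Authorization": "Bearer pb_xxx", "x-request-id": "abc123"}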
litellm/llms/predibase/chat/transformation.py (new file, 185 lines)

@@ -0,0 +1,185 @@
+import types
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import PredibaseError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class PredibaseConfig(BaseConfig):
+    """
+    Reference: https://docs.predibase.com/user-guide/inference/rest_api
+    """
+
+    adapter_id: Optional[str] = None
+    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
+    best_of: Optional[int] = None
+    decoder_input_details: Optional[bool] = None
+    details: bool = True  # enables returning logprobs + best of
+    max_new_tokens: int = (
+        256  # openai default - requests hang if max_new_tokens not given
+    )
+    repetition_penalty: Optional[float] = None
+    return_full_text: Optional[bool] = (
+        False  # by default don't return the input as part of the output
+    )
+    seed: Optional[int] = None
+    stop: Optional[List[str]] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[int] = None
+    truncate: Optional[int] = None
+    typical_p: Optional[float] = None
+    watermark: Optional[bool] = None
+
+    def __init__(
+        self,
+        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        details: Optional[bool] = None,
+        max_new_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[int] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "stream",
+            "temperature",
+            "max_completion_tokens",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens" or param == "max_completion_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        return messages
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PredibaseError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
+            )
+
+        default_headers = {
+            "content-type": "application/json",
+            "Authorization": "Bearer {}".format(api_key),
+        }
+        if headers is not None and isinstance(headers, dict):
+            headers = {**default_headers, **headers}
+        return headers
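The class mirrors the PredibaseConfig that was deleted from the handler earlier in this diff, now subclassing BaseConfig and taking the extra model/drop_params arguments. A small sketch of the parameter mapping (values are illustrative):

    from litellm.llms.predibase.chat.transformation import PredibaseConfig

    mapped = PredibaseConfig().map_openai_params(
        non_default_params={"max_tokens": 0, "temperature": 0.0, "n": 2},
        optional_params={},
        model="llama-3-1-8b-instruct",  # placeholder; the mapping does not use the model name
        drop_params=False,
    )
    # max_tokens == 0 and temperature == 0 are bumped to the minimum values TGI accepts,
    # and "n" is translated to best_of + do_sample:
    # {"max_new_tokens": 1, "temperature": 0.01, "best_of": 2, "do_sample": True}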
litellm/llms/predibase/common_utils.py (new file, 23 lines)

@@ -0,0 +1,23 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class PredibaseError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )
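PredibaseError now derives from the shared BaseLLMException rather than hand-rolling httpx request/response defaults. A sketch of how the config surfaces it, assuming BaseLLMException keeps status_code and message as attributes (the error text and status are illustrative):

    from litellm.llms.predibase.chat.transformation import PredibaseConfig
    from litellm.llms.predibase.common_utils import PredibaseError

    err = PredibaseConfig().get_error_class(
        error_message="tenant not found",  # illustrative message
        status_code=404,
        headers={},
    )
    assert isinstance(err, PredibaseError)
    assert err.status_code == 404 and err.message == "tenant not found"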
@@ -104,13 +104,12 @@ from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.codestral.completion.handler import CodestralTextCompletion
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
-from .llms.deprecated_providers import palm, aleph_alpha
+from .llms.deprecated_providers import aleph_alpha, palm
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface.chat.handler import Huggingface
 from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion

@@ -121,10 +120,11 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
-from .llms.predibase import PredibaseChatCompletion
+from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
+from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
@@ -2328,6 +2328,11 @@ def completion(  # type: ignore # noqa: PLR0915
             or get_secret("PREDIBASE_TENANT_ID")
         )
 
+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
+
         api_base = (
             api_base
             or optional_params.pop("api_base", None)
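completion() now fails fast with a clear error when the tenant ID is missing instead of surfacing the failure later in the handler. The call shape is unchanged; an illustrative example (model, key, and tenant values are placeholders):

    import litellm

    response = litellm.completion(
        model="predibase/llama-3-1-8b-instruct",      # placeholder model
        messages=[{"role": "user", "content": "Hello"}],
        api_key="pb_xxx",                             # or set PREDIBASE_API_KEY
        tenant_id="my-tenant",                        # or set PREDIBASE_TENANT_ID
    )
    # Omitting tenant_id (and PREDIBASE_TENANT_ID) now raises the ValueError added above
    # before any HTTP request is attempted.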
@@ -369,8 +369,14 @@ def function_setup(  # noqa: PLR0915
                    litellm._async_success_callback.append(callback)
                    removed_async_items.append(index)
                elif callback in litellm._known_custom_logger_compatible_callbacks:
-                    callback_class = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
-                        callback, internal_usage_cache=None, llm_router=None  # type: ignore
+                    from litellm.litellm_core_utils.litellm_logging import (
+                        _init_custom_logger_compatible_class,
+                    )
+
+                    callback_class = _init_custom_logger_compatible_class(
+                        callback,  # type: ignore
+                        internal_usage_cache=None,
+                        llm_router=None,  # type: ignore
                    )
 
                    # don't double add a callback
@@ -2941,7 +2947,14 @@ def get_optional_params(  # noqa: PLR0915
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.PredibaseConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
        )
    elif custom_llm_provider == "huggingface":
        ## check if unsupported param passed in
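For reference, this is the branch a Predibase call exercises when its OpenAI-style parameters are normalized. A rough sketch of the equivalent call through the helper, assuming get_optional_params keeps its keyword interface (the model name is a placeholder):

    from litellm.utils import get_optional_params

    optional_params = get_optional_params(
        model="llama-3-1-8b-instruct",   # placeholder
        custom_llm_provider="predibase",
        max_tokens=10,
        temperature=0.0,
    )
    # Expected to contain the Predibase-mapped values
    # {"max_new_tokens": 10, "temperature": 0.01} among the returned params.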
@@ -4102,7 +4115,6 @@ def get_api_base(
                partner=VertexPartnerProvider.claude,
            )
        else:
-
            if stream:
                _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
                    _optional_params.vertex_location,
@@ -6359,6 +6371,8 @@ class ProviderConfigManager:
            return litellm.VLLMConfig()
        elif litellm.LlmProviders.OLLAMA == provider:
            return litellm.OllamaConfig()
+        elif litellm.LlmProviders.PREDIBASE == provider:
+            return litellm.PredibaseConfig()
        return litellm.OpenAIGPTConfig()
 
 
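With the new branch, the generic provider-config lookup resolves Predibase to the shared config class. A hedged sketch; the classmethod name get_provider_chat_config is assumed from current litellm sources and may differ in this revision:

    import litellm
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_chat_config(
        model="llama-3-1-8b-instruct",  # placeholder model name
        provider=litellm.LlmProviders.PREDIBASE,
    )
    assert isinstance(config, litellm.PredibaseConfig)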
@@ -204,12 +204,16 @@ def test_all_model_configs():
        drop_params=False,
    ) == {"num_predict": 10}
 
-    from litellm.llms.predibase import PredibaseConfig
+    from litellm.llms.predibase.chat.transformation import PredibaseConfig
 
-    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in PredibaseConfig().get_supported_openai_params(
+        model="llama3"
+    )
    assert PredibaseConfig().map_openai_params(
-        {"max_completion_tokens": 10},
-        {},
+        model="llama3",
+        non_default_params={"max_completion_tokens": 10},
+        optional_params={},
+        drop_params=False,
    ) == {"max_new_tokens": 10}
 
    from litellm.llms.codestral.completion.transformation import (