Add inference providers support for Hugging Face (#8258)

* add first version of inference providers for huggingface
* temporarily skipping tests
* Add documentation
* Fix titles
* remove max_retries from params and clean up
* add suggestions
* use llm http handler
* update doc
* add suggestions
* run formatters
* add tests
* revert
* revert
* rename file
* set maxsize for lru cache
* fix embeddings
* fix inference url
* fix tests following breaking change in main
* use ChatCompletionRequest
* fix tests and lint

[Hugging Face] Remove outdated chat completion tests and fix embedding tests (#9749)

* remove or fix tests
* fix link in doc
* fix(config_settings.md): document hf api key

---------

Co-authored-by: célina <hanouticelina@gmail.com>
import logging
import os
from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any

from litellm.llms.base_llm.chat.transformation import BaseLLMException

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping

logger = logging.getLogger(__name__)

BASE_URL = "https://router.huggingface.co"


class HuggingFaceChatConfig(OpenAIGPTConfig):
    """
    Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
    """
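
    # Model id formats accepted by this config (as parsed by the routing logic
    # in get_complete_url/transform_request below; model names are illustrative):
    #   "together/deepseek-ai/DeepSeek-R1" -> provider "together"
    #   "deepseek-ai/DeepSeek-R1"          -> provider "hf-inference" (the default)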

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        default_headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            default_headers["Authorization"] = f"Bearer {api_key}"

        # Merge with caller-supplied headers; the defaults win on key conflicts.
        headers = {**headers, **default_headers}

        return headers
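
    # For example (illustrative token value), passing api_key="hf_xxx" yields:
    #   {"content-type": "application/json", "Authorization": "Bearer hf_xxx"}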

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return HuggingFaceError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
        """
        Get the API base for the Hugging Face API.

        Do not add the chat/embedding/rerank extension here; the handler does that.
        """
        if model.startswith(("http://", "https://")):
            base_url = model
        elif base_url is None:
            base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
        return base_url
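
    # For example, a model passed as a full URL (hypothetical endpoint) is
    # returned as-is: get_base_url("https://my-endpoint.example/v1", None)
    # -> "https://my-endpoint.example/v1". Otherwise the env vars above apply.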

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.

        Handles provider-specific routing through the Hugging Face router.
        """
        # 1. Explicitly provided api_base takes precedence.
        if api_base is not None:
            complete_url = api_base
        # 2. Environment variable override.
        elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
            complete_url = str(
                os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE")
            )
        # 3. Model passed as a full URL.
        elif model.startswith(("http://", "https://")):
            complete_url = model
        # 4. Default construction with provider.
        else:
            # Parse provider and model; a bare "<owner>/<repo>" id defaults
            # to the hf-inference provider.
            first_part, remaining = model.split("/", 1)
            if "/" in remaining:
                provider = first_part
            else:
                provider = "hf-inference"

            if provider == "hf-inference":
                route = f"{provider}/models/{model}/v1/chat/completions"
            elif provider == "novita":
                route = f"{provider}/chat/completions"
            else:
                route = f"{provider}/v1/chat/completions"
            complete_url = f"{BASE_URL}/{route}"

        # Ensure URL doesn't end with a slash
        complete_url = complete_url.rstrip("/")
        return complete_url
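
    # URL construction examples (derived from the branches above; model names
    # are illustrative):
    #   "together/deepseek-ai/DeepSeek-R1"
    #       -> https://router.huggingface.co/together/v1/chat/completions
    #   "novita/deepseek-ai/DeepSeek-R1"
    #       -> https://router.huggingface.co/novita/chat/completions
    #   "deepseek-ai/DeepSeek-R1" (no provider prefix, defaults to hf-inference)
    #       -> https://router.huggingface.co/hf-inference/models/deepseek-ai/DeepSeek-R1/v1/chat/completions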

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        if "max_retries" in optional_params:
            logger.warning("`max_retries` is not supported. It will be ignored.")
            optional_params.pop("max_retries", None)
        # Split the provider prefix from the model id; a bare "<owner>/<repo>"
        # id defaults to the hf-inference provider.
        first_part, remaining = model.split("/", 1)
        if "/" in remaining:
            provider = first_part
            model_id = remaining
        else:
            provider = "hf-inference"
            model_id = model
        provider_mapping = _fetch_inference_provider_mapping(model_id)
        if provider not in provider_mapping:
            raise HuggingFaceError(
                message=f"Model {model_id} is not supported for provider {provider}",
                status_code=404,
                headers={},
            )
        provider_mapping = provider_mapping[provider]
        if provider_mapping["status"] == "staging":
            logger.warning(
                f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
            )
        # Map to the provider-specific model id before building the request.
        mapped_model = provider_mapping["providerId"]
        messages = self._transform_messages(messages=messages, model=mapped_model)
        return dict(
            ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params)
        )
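

# Minimal usage sketch (assumptions: the "huggingface/" litellm provider prefix
# and HF_TOKEN auth as described in this PR's docs; model name is illustrative):
#
#   import os
#   import litellm
#
#   response = litellm.completion(
#       model="huggingface/together/deepseek-ai/DeepSeek-R1",
#       messages=[{"role": "user", "content": "Hello!"}],
#       api_key=os.getenv("HF_TOKEN"),
#   )
#   print(response.choices[0].message.content)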