litellm-mirror/litellm/llms/huggingface/chat/transformation.py
Krish Dholakia 34bdf36eab
Add inference providers support for Hugging Face (#8258) (#9738) (#9773)
* Add inference providers support for Hugging Face (#8258)

* add first version of inference providers for huggingface

* temporarily skipping tests

* Add documentation

* Fix titles

* remove max_retries from params and clean up

* add suggestions

* use llm http handler

* update doc

* add suggestions

* run formatters

* add tests

* revert

* revert

* rename file

* set maxsize for lru cache

* fix embeddings

* fix inference url

* fix tests following breaking change in main

* use ChatCompletionRequest

* fix tests and lint

* [Hugging Face] Remove outdated chat completion tests and fix embedding tests (#9749)

* remove or fix tests

* fix link in doc

* fix(config_settings.md): document hf api key

---------

Co-authored-by: célina <hanouticelina@gmail.com>
2025-04-05 10:50:15 -07:00

141 lines
4.8 KiB
Python

import logging
import os
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import httpx

from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
# Module-level logger; this module uses it to warn about the ignored
# `max_retries` parameter and about models that are still in staging.
logger = logging.getLogger(__name__)
# Root URL that the default, provider-specific chat-completion routes are
# appended to when no explicit API base is supplied.
BASE_URL = "https://router.huggingface.co"
class HuggingFaceChatConfig(OpenAIGPTConfig):
    """
    Chat-completion configuration for Hugging Face inference providers.

    Builds the request headers, resolves the endpoint URL (explicit API base,
    environment override, URL-style model, or a provider route under
    ``BASE_URL``) and rewrites the request so it carries the provider-specific
    model id.

    Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
    """

    @staticmethod
    def _split_provider_and_model(model: str) -> Tuple[str, str]:
        """
        Split ``model`` into ``(provider, model_id)``.

        ``"<provider>/<org>/<name>"`` yields ``("<provider>", "<org>/<name>")``.
        Anything with fewer than two slashes has no provider prefix, so the
        whole string is the model id and the provider is ``"hf-inference"``.
        """
        # `partition` never raises, unlike unpacking `split("/", 1)`, so a
        # model id without any "/" simply falls through to the default provider.
        first_part, _, remaining = model.partition("/")
        if "/" in remaining:
            return first_part, remaining
        return "hf-inference", model

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Return the headers to send with the request.

        A JSON ``content-type`` header is always set, and an ``Authorization``
        bearer header is added when ``api_key`` is given. These values win over
        same-named keys already present in ``headers``. The input dict is not
        mutated; a new dict is returned.
        """
        default_headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            default_headers["Authorization"] = f"Bearer {api_key}"

        return {**headers, **default_headers}

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an error response in a ``HuggingFaceError``."""
        return HuggingFaceError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
        """
        Get the API base for the Huggingface API.

        Do not add the chat/embedding/rerank extension here. Let the handler do this.

        A model that is itself an http(s) URL takes precedence over
        ``base_url``. Otherwise, when ``base_url`` is None, the value comes from
        ``HF_API_BASE`` or ``HUGGINGFACE_API_BASE`` (empty string if neither is
        set).
        """
        if model.startswith(("http://", "https://")):
            return model
        if base_url is None:
            return os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
        return base_url

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.

        Resolution order:
          1. ``api_base`` when provided.
          2. ``HF_API_BASE``, then ``HUGGINGFACE_API_BASE``, from the environment.
          3. ``model`` itself when it is an http(s) URL.
          4. A provider-specific route under ``BASE_URL``.

        The returned URL never ends with a slash.
        """
        # Read the environment once. The previous `str(os.getenv(...)) or ...`
        # turned an unset HF_API_BASE into the truthy string "None", so the
        # HUGGINGFACE_API_BASE fallback could never be reached.
        env_api_base = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE")

        if api_base is not None:
            # 1. Explicit api_base
            complete_url = api_base
        elif env_api_base:
            # 2. Environment override
            complete_url = env_api_base
        elif model.startswith(("http://", "https://")):
            # 3. The model is already a full endpoint URL
            complete_url = model
        else:
            # 4. Default construction with provider
            provider, model_id = self._split_provider_and_model(model)
            if provider == "hf-inference":
                # Use the id without any provider prefix so an explicit
                # "hf-inference/<org>/<name>" does not repeat the provider
                # inside the model path.
                route = f"{provider}/models/{model_id}/v1/chat/completions"
            elif provider == "novita":
                route = f"{provider}/chat/completions"
            else:
                route = f"{provider}/v1/chat/completions"
            complete_url = f"{BASE_URL}/{route}"

        # Ensure URL doesn't end with a slash
        return complete_url.rstrip("/")

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the chat-completion request body for the selected provider.

        The provider is parsed from ``model``, and the provider mapping fetched
        for the model id supplies the provider-side model name (``providerId``)
        that is sent in the request.

        Side effect: ``max_retries`` is removed from ``optional_params`` (with
        a warning) because it is not supported.

        Raises:
            HuggingFaceError: (status 404) when the mapping has no entry for
                the requested provider.
        """
        if "max_retries" in optional_params:
            logger.warning("`max_retries` is not supported. It will be ignored.")
            optional_params.pop("max_retries", None)

        provider, model_id = self._split_provider_and_model(model)

        provider_mapping = _fetch_inference_provider_mapping(model_id)
        if provider not in provider_mapping:
            raise HuggingFaceError(
                message=f"Model {model_id} is not supported for provider {provider}",
                status_code=404,
                headers={},
            )

        provider_info = provider_mapping[provider]
        if provider_info["status"] == "staging":
            logger.warning(
                "Model %s is in staging mode for provider %s. Meant for test purposes only.",
                model_id,
                provider,
            )

        # The provider may expose the model under a different name than the Hub id.
        mapped_model = provider_info["providerId"]
        messages = self._transform_messages(messages=messages, model=mapped_model)
        return dict(
            ChatCompletionRequest(
                model=mapped_model, messages=messages, **optional_params
            )
        )