Add inference providers support for Hugging Face (#8258) (#9738) (#9773)

* Add inference providers support for Hugging Face (#8258)

* add first version of inference providers for huggingface

* temporarily skipping tests

* Add documentation

* Fix titles

* remove max_retries from params and clean up

* add suggestions

* use llm http handler

* update doc

* add suggestions

* run formatters

* add tests

* revert

* revert

* rename file

* set maxsize for lru cache

* fix embeddings

* fix inference url

* fix tests following breaking change in main

* use ChatCompletionRequest

* fix tests and lint

* [Hugging Face] Remove outdated chat completion tests and fix embedding tests (#9749)

* remove or fix tests

* fix link in doc

* fix(config_settings.md): document hf api key

---------

Co-authored-by: célina <hanouticelina@gmail.com>
Krish Dholakia 2025-04-05 10:50:15 -07:00 committed by GitHub
parent 0d503ad8ad
commit 34bdf36eab
24 changed files with 2052 additions and 2456 deletions
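The headline change lets litellm route Hugging Face requests through third-party inference providers. Below is a minimal sketch of how a caller might exercise the new path, assuming the huggingface/<provider>/<model> model-string convention and the HUGGINGFACE_API_KEY environment variable that the helper in the diff reads; the provider and model ids are illustrative, not prescribed by this commit:

    import os

    import litellm

    # Illustrative: route a chat completion through a specific inference
    # provider via the Hugging Face integration. "together" and the model
    # id are placeholders; any provider listed in the model's
    # inferenceProviderMapping should be addressable the same way.
    os.environ["HUGGINGFACE_API_KEY"] = "hf_..."  # read from the environment

    response = litellm.completion(
        model="huggingface/together/deepseek-ai/DeepSeek-R1",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)

Under this convention, the provider segment would be resolved against the model's inferenceProviderMapping fetched from the Hub by the helper added in the diff below.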


@@ -1,18 +1,30 @@
+import os
+from functools import lru_cache
 from typing import Literal, Optional, Union
 
 import httpx
 
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
+HF_HUB_URL = "https://huggingface.co"
+
 
-class HuggingfaceError(BaseLLMException):
+class HuggingFaceError(BaseLLMException):
     def __init__(
         self,
-        status_code: int,
-        message: str,
-        headers: Optional[Union[dict, httpx.Headers]] = None,
+        status_code,
+        message,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
     ):
-        super().__init__(status_code=status_code, message=message, headers=headers)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )
 
 
 hf_tasks = Literal[
@@ -43,3 +55,48 @@ def output_parser(generated_text: str):
         if generated_text.endswith(token):
             generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
     return generated_text
+
+
+@lru_cache(maxsize=128)
+def _fetch_inference_provider_mapping(model: str) -> dict:
+    """
+    Fetch provider mappings for a model from the Hugging Face Hub.
+
+    Args:
+        model: The model identifier (e.g., 'meta-llama/Llama-2-7b')
+
+    Returns:
+        dict: The inference provider mapping for the model
+
+    Raises:
+        ValueError: If no provider mapping is found
+        HuggingFaceError: If the API request fails
+    """
+    headers = {"Accept": "application/json"}
+    if os.getenv("HUGGINGFACE_API_KEY"):
+        headers["Authorization"] = f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"
+
+    path = f"{HF_HUB_URL}/api/models/{model}"
+    params = {"expand": ["inferenceProviderMapping"]}
+
+    try:
+        response = httpx.get(path, headers=headers, params=params)
+        response.raise_for_status()
+        provider_mapping = response.json().get("inferenceProviderMapping")
+
+        if provider_mapping is None:
+            raise ValueError(f"No provider mapping found for model {model}")
+
+        return provider_mapping
+    except httpx.HTTPError as e:
+        if hasattr(e, "response"):
+            status_code = getattr(e.response, "status_code", 500)
+            headers = getattr(e.response, "headers", {})
+        else:
+            status_code = 500
+            headers = {}
+
+        raise HuggingFaceError(
+            message=f"Failed to fetch provider mapping: {str(e)}",
+            status_code=status_code,
+            headers=headers,
+        )
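For reference, a hedged sketch of calling the new cached helper directly; the per-provider value shape noted in the comments is an assumption about the Hub's inferenceProviderMapping expansion, not something this diff guarantees:

    # Illustrative only: fetch the provider mapping for a model id.
    # Assumed shape: {provider_name: {...provider-specific fields...}}.
    mapping = _fetch_inference_provider_mapping("meta-llama/Llama-2-7b")
    for provider, info in mapping.items():
        print(provider, info)

    # Thanks to @lru_cache(maxsize=128), a repeat call with the same model
    # id returns the cached object without a second HTTP request.
    assert _fetch_inference_provider_mapping("meta-llama/Llama-2-7b") is mapping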