litellm-mirror/litellm/llms/huggingface/common_utils.py

import os
from functools import lru_cache
from typing import Literal, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException

HF_HUB_URL = "https://huggingface.co"


class HuggingFaceError(BaseLLMException):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[httpx.Headers, dict]] = None,
    ):
        super().__init__(
            status_code=status_code,
            message=message,
            request=request,
            response=response,
            headers=headers,
        )
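
# Illustrative sketch (not part of the original module): callers can catch
# HuggingFaceError around any Hugging Face call. This assumes BaseLLMException
# exposes `status_code` and `message` attributes, as the constructor above
# suggests.
#
#     try:
#         mapping = _fetch_inference_provider_mapping("some-org/some-model")
#     except HuggingFaceError as e:
#         print(f"HF request failed ({e.status_code}): {e.message}")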


hf_tasks = Literal[
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

# Runtime counterpart of the `hf_tasks` Literal: the same task names as a plain
# list, usable for membership checks at runtime.
hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]
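
# Illustrative (not part of the original module): hf_task_list enables a plain
# runtime membership check, where the static `hf_tasks` Literal cannot be used
# directly.
#
#     task = "conversational"
#     if task not in hf_task_list:
#         raise ValueError(f"Unsupported Hugging Face task: {task}")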


def output_parser(generated_text: str):
    """
    Strip chat-template special tokens from the start and end of the generated
    text. Currently only ChatML tokens are checked.

    Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
    """
    chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
    for token in chat_template_tokens:
        if generated_text.strip().startswith(token):
            generated_text = generated_text.replace(token, "", 1)
        if generated_text.endswith(token):
            # Reverse the string so `replace(..., 1)` removes only the *last*
            # occurrence of the token, then reverse back.
            generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
    return generated_text
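
# Illustrative example (not part of the original module): a leading role token
# and a trailing EOS token are stripped; other whitespace is left intact.
#
#     >>> output_parser("<|assistant|> Hello there!</s>")
#     ' Hello there!'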


@lru_cache(maxsize=128)
def _fetch_inference_provider_mapping(model: str) -> dict:
    """
    Fetch provider mappings for a model from the Hugging Face Hub.

    Args:
        model: The model identifier (e.g., 'meta-llama/Llama-2-7b')

    Returns:
        dict: The inference provider mapping for the model

    Raises:
        ValueError: If no provider mapping is found
        HuggingFaceError: If the API request fails
    """
    headers = {"Accept": "application/json"}
    if os.getenv("HUGGINGFACE_API_KEY"):
        headers["Authorization"] = f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"

    path = f"{HF_HUB_URL}/api/models/{model}"
    params = {"expand": ["inferenceProviderMapping"]}

    try:
        response = httpx.get(path, headers=headers, params=params)
        response.raise_for_status()
        provider_mapping = response.json().get("inferenceProviderMapping")
        if provider_mapping is None:
            raise ValueError(f"No provider mapping found for model {model}")
        return provider_mapping
    except httpx.HTTPError as e:
        # httpx.HTTPStatusError carries the response; httpx.RequestError does not.
        if hasattr(e, "response"):
            status_code = getattr(e.response, "status_code", 500)
            error_headers = getattr(e.response, "headers", {})
        else:
            status_code = 500
            error_headers = {}
        raise HuggingFaceError(
            message=f"Failed to fetch provider mapping: {str(e)}",
            status_code=status_code,
            headers=error_headers,
        )
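
# Usage sketch (illustrative, not part of the original module). Assumes network
# access to the Hugging Face Hub; the model id is an example, and setting
# HUGGINGFACE_API_KEY additionally allows gated models to be resolved.
#
#     mapping = _fetch_inference_provider_mapping("meta-llama/Llama-2-7b")
#     for provider, details in mapping.items():
#         print(provider, details)
#
# Because of @lru_cache(maxsize=128), repeated lookups for the same model id
# are served from memory rather than re-querying the Hub.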