mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
* Add inference providers support for Hugging Face (#8258) * add first version of inference providers for huggingface * temporarily skipping tests * Add documentation * Fix titles * remove max_retries from params and clean up * add suggestions * use llm http handler * update doc * add suggestions * run formatters * add tests * revert * revert * rename file * set maxsize for lru cache * fix embeddings * fix inference url * fix tests following breaking change in main * use ChatCompletionRequest * fix tests and lint * [Hugging Face] Remove outdated chat completion tests and fix embedding tests (#9749) * remove or fix tests * fix link in doc * fix(config_settings.md): document hf api key --------- Co-authored-by: célina <hanouticelina@gmail.com>
This commit is contained in:
parent
0d503ad8ad
commit
34bdf36eab
24 changed files with 2052 additions and 2456 deletions
|
@ -1,18 +1,30 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
# Base URL of the Hugging Face Hub; the model-info REST API lives under /api/models/.
HF_HUB_URL = "https://huggingface.co"
|
||||
|
||||
class HuggingFaceError(BaseLLMException):
    """Exception raised when a call to the Hugging Face API fails.

    Thin wrapper over the shared ``BaseLLMException`` so callers can catch
    Hugging Face failures specifically while the base class carries the
    HTTP status code, message, and response headers.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ):
        # All state lives on the base class; nothing HF-specific to store here.
        super().__init__(status_code=status_code, message=message, headers=headers)
|
||||
|
||||
|
||||
hf_tasks = Literal[
|
||||
|
@ -43,3 +55,48 @@ def output_parser(generated_text: str):
|
|||
if generated_text.endswith(token):
|
||||
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
|
||||
return generated_text
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
def _fetch_inference_provider_mapping(model: str) -> dict:
    """
    Fetch provider mappings for a model from the Hugging Face Hub.

    Results are memoized (bounded LRU, 128 entries) since the mapping for a
    given model id is stable within a process lifetime.

    Args:
        model: The model identifier (e.g., 'meta-llama/Llama-2-7b')

    Returns:
        dict: The inference provider mapping for the model

    Raises:
        ValueError: If no provider mapping is found
        HuggingFaceError: If the API request fails
    """
    request_headers = {"Accept": "application/json"}
    # Read the token once; only attach Authorization when it is set and non-empty.
    hf_token = os.getenv("HUGGINGFACE_API_KEY")
    if hf_token:
        request_headers["Authorization"] = f"Bearer {hf_token}"

    path = f"{HF_HUB_URL}/api/models/{model}"
    params = {"expand": ["inferenceProviderMapping"]}

    try:
        response = httpx.get(path, headers=request_headers, params=params)
        response.raise_for_status()
    except httpx.HTTPError as e:
        # Only httpx.HTTPStatusError (raised by raise_for_status) carries a
        # response; transport-level errors (timeouts, DNS, connect) do not.
        if isinstance(e, httpx.HTTPStatusError):
            status_code = e.response.status_code
            error_headers = e.response.headers
        else:
            status_code = 500
            error_headers = {}
        raise HuggingFaceError(
            message=f"Failed to fetch provider mapping: {str(e)}",
            status_code=status_code,
            headers=error_headers,
        ) from e

    # ValueError (and any JSON decode error) intentionally propagates: it is a
    # caller-facing "no mapping" condition, not an HTTP transport failure.
    provider_mapping = response.json().get("inferenceProviderMapping")
    if provider_mapping is None:
        raise ValueError(f"No provider mapping found for model {model}")
    return provider_mapping
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue