Add inference providers support for Hugging Face (#8258) (#9738) (#9773)

* Add inference providers support for Hugging Face (#8258)

* add first version of inference providers for huggingface

* temporarily skipping tests

* Add documentation

* Fix titles

* remove max_retries from params and clean up

* add suggestions

* use llm http handler

* update doc

* add suggestions

* run formatters

* add tests

* revert

* revert

* rename file

* set maxsize for lru cache

* fix embeddings

* fix inference url

* fix tests following breaking change in main

* use ChatCompletionRequest

* fix tests and lint

* [Hugging Face] Remove outdated chat completion tests and fix embedding tests (#9749)

* remove or fix tests

* fix link in doc

* fix(config_settings.md): document hf api key

---------

Co-authored-by: célina <hanouticelina@gmail.com>
This commit is contained in:
Krish Dholakia 2025-04-05 10:50:15 -07:00 committed by GitHub
parent 0d503ad8ad
commit 34bdf36eab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 2052 additions and 2456 deletions

View file

@@ -141,7 +141,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
from .llms.deprecated_providers import aleph_alpha, palm
from .llms.groq.chat.handler import GroqChatCompletion
from .llms.huggingface.chat.handler import Huggingface
from .llms.huggingface.embedding.handler import HuggingFaceEmbedding
from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
from .llms.ollama.completion import handler as ollama
from .llms.oobabooga.chat import oobabooga
@@ -221,7 +221,7 @@ azure_chat_completions = AzureChatCompletion()
azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
azure_text_completions = AzureTextCompletion()
azure_audio_transcriptions = AzureAudioTranscription()
huggingface = Huggingface()
huggingface_embed = HuggingFaceEmbedding()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
bedrock_converse_chat_completion = BedrockConverseLLM()
@@ -2141,7 +2141,6 @@ def completion( # type: ignore # noqa: PLR0915
response = model_response
elif custom_llm_provider == "huggingface":
custom_llm_provider = "huggingface"
huggingface_key = (
api_key
or litellm.huggingface_key
@@ -2150,40 +2149,23 @@ def completion( # type: ignore # noqa: PLR0915
or litellm.api_key
)
hf_headers = headers or litellm.headers
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = huggingface.completion(
response = base_llm_http_handler.completion(
model=model,
messages=messages,
api_base=api_base, # type: ignore
headers=hf_headers or {},
headers=hf_headers,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=huggingface_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
optional_params=optional_params,
litellm_params=litellm_params,
timeout=timeout, # type: ignore
client=client,
custom_llm_provider=custom_llm_provider,
encoding=encoding,
stream=stream,
)
if (
"stream" in optional_params
and optional_params["stream"] is True
and acompletion is False
):
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="huggingface",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "oobabooga":
custom_llm_provider = "oobabooga"
model_response = oobabooga.completion(
@@ -3623,7 +3605,7 @@ def embedding( # noqa: PLR0915
or get_secret("HUGGINGFACE_API_KEY")
or litellm.api_key
) # type: ignore
response = huggingface.embedding(
response = huggingface_embed.embedding(
model=model,
input=input,
encoding=encoding, # type: ignore