Add inference providers support for Hugging Face (#8258) (#9738) (#9773)

* Add inference providers support for Hugging Face (#8258)

* add first version of inference providers for huggingface

* temporarily skipping tests

* Add documentation

* Fix titles

* remove max_retries from params and clean up

* add suggestions

* use llm http handler

* update doc

* add suggestions

* run formatters

* add tests

* revert

* revert

* rename file

* set maxsize for lru cache

* fix embeddings

* fix inference url

* fix tests following breaking change in main

* use ChatCompletionRequest

* fix tests and lint

* [Hugging Face] Remove outdated chat completion tests and fix embedding tests (#9749)

* remove or fix tests

* fix link in doc

* fix(config_settings.md): document hf api key

---------

Co-authored-by: célina <hanouticelina@gmail.com>
Krish Dholakia 2025-04-05 10:50:15 -07:00 committed by GitHub
parent 0d503ad8ad
commit 34bdf36eab
24 changed files with 2052 additions and 2456 deletions


@ -6,8 +6,9 @@
"id": "9dKM5k8qsMIj"
},
"source": [
"## LiteLLM HuggingFace\n",
"Docs for huggingface: https://docs.litellm.ai/docs/providers/huggingface"
"## LiteLLM Hugging Face\n",
"\n",
"Docs for huggingface: https://docs.litellm.ai/docs/providers/huggingface\n"
]
},
{
@ -27,23 +28,18 @@
"id": "yp5UXRqtpu9f"
},
"source": [
"## Hugging Face Free Serverless Inference API\n",
"Read more about the Free Serverless Inference API here: https://huggingface.co/docs/api-inference.\n",
"## Serverless Inference Providers\n",
"\n",
"In order to use litellm to call Serverless Inference API:\n",
"Read more about Inference Providers here: https://huggingface.co/blog/inference-providers.\n",
"\n",
"* Browse Serverless Inference compatible models here: https://huggingface.co/models?inference=warm&pipeline_tag=text-generation.\n",
"* Copy the model name from hugging face\n",
"* Set `model = \"huggingface/<model-name>\"`\n",
"In order to use litellm with Hugging Face Inference Providers, you need to set `model=huggingface/<provider>/<model-id>`.\n",
"\n",
"Example set `model=huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct` to call `meta-llama/Meta-Llama-3.1-8B-Instruct`\n",
"\n",
"https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct"
"Example: `huggingface/together/deepseek-ai/DeepSeek-R1` to run DeepSeek-R1 (https://huggingface.co/deepseek-ai/DeepSeek-R1) through Together AI.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -51,107 +47,18 @@
"id": "Pi5Oww8gpCUm",
"outputId": "659a67c7-f90d-4c06-b94e-2c4aa92d897a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelResponse(id='chatcmpl-c54dfb68-1491-4d68-a4dc-35e603ea718a', choices=[Choices(finish_reason='eos_token', index=0, message=Message(content=\"I'm just a computer program, so I don't have feelings, but thank you for asking! How can I assist you today?\", role='assistant', tool_calls=None, function_call=None))], created=1724858285, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=27, prompt_tokens=47, total_tokens=74))\n",
"ModelResponse(id='chatcmpl-d2ae38e6-4974-431c-bb9b-3fa3f95e5a6d', choices=[Choices(finish_reason='length', index=0, message=Message(content=\"\\n\\nIm doing well, thank you. Ive been keeping busy with work and some personal projects. How about you?\\n\\nI'm doing well, thank you. I've been enjoying some time off and catching up on some reading. How can I assist you today?\\n\\nI'm looking for a good book to read. Do you have any recommendations?\\n\\nOf course! Here are a few book recommendations across different genres:\\n\\n1.\", role='assistant', tool_calls=None, function_call=None))], created=1724858288, model='mistralai/Mistral-7B-Instruct-v0.3', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=85, prompt_tokens=6, total_tokens=91))\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"import litellm\n",
"from litellm import completion\n",
"\n",
"# Make sure to create an API_KEY with inference permissions at https://huggingface.co/settings/tokens/new?globalPermissions=inference.serverless.write&tokenType=fineGrained\n",
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
"# You can create a HF token here: https://huggingface.co/settings/tokens\n",
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"\n",
"# Call https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct\n",
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n",
"response = litellm.completion(\n",
" model=\"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}]\n",
")\n",
"print(response)\n",
"\n",
"\n",
"# Call https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3\n",
"response = litellm.completion(\n",
" model=\"huggingface/mistralai/Mistral-7B-Instruct-v0.3\",\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}]\n",
")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-klhAhjLtclv"
},
"source": [
"## Hugging Face Dedicated Inference Endpoints\n",
"\n",
"Steps to use\n",
"* Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/\n",
"* Set `api_base` to your deployed api base\n",
"* Add the `huggingface/` prefix to your model so litellm knows it's a huggingface Deployed Inference Endpoint"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Lbmw8Gl_pHns",
"outputId": "ea8408bf-1cc3-4670-ecea-f12666d204a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"object\": \"chat.completion\",\n",
" \"choices\": [\n",
" {\n",
" \"finish_reason\": \"length\",\n",
" \"index\": 0,\n",
" \"message\": {\n",
" \"content\": \"\\n\\nI am doing well, thank you for asking. How about you?\\nI am doing\",\n",
" \"role\": \"assistant\",\n",
" \"logprobs\": -8.9481967812\n",
" }\n",
" }\n",
" ],\n",
" \"id\": \"chatcmpl-74dc9d89-3916-47ce-9bea-b80e66660f77\",\n",
" \"created\": 1695871068.8413374,\n",
" \"model\": \"glaiveai/glaive-coder-7b\",\n",
" \"usage\": {\n",
" \"prompt_tokens\": 6,\n",
" \"completion_tokens\": 18,\n",
" \"total_tokens\": 24\n",
" }\n",
"}\n"
]
}
],
"source": [
"import os\n",
"import litellm\n",
"\n",
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
"\n",
"# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b\n",
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n",
"# set api base to your deployed api endpoint from hugging face\n",
"response = litellm.completion(\n",
" model=\"huggingface/glaiveai/glaive-coder-7b\",\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n",
" api_base=\"https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud\"\n",
"# Call DeepSeek-R1 model through Together AI\n",
"response = completion(\n",
" model=\"huggingface/together/deepseek-ai/DeepSeek-R1\",\n",
" messages=[{\"content\": \"How many r's are in the word `strawberry`?\", \"role\": \"user\"}],\n",
")\n",
"print(response)"
]
@ -162,13 +69,12 @@
"id": "EU0UubrKzTFe"
},
"source": [
"## HuggingFace - Streaming (Serveless or Dedicated)\n",
"Set stream = True"
"## Streaming\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -176,74 +82,147 @@
"id": "y-QfIvA-uJKX",
"outputId": "b007bb98-00d0-44a4-8264-c8a2caed6768"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<litellm.utils.CustomStreamWrapper object at 0x1278471d0>\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='I', role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"'m\", role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' just', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' a', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' computer', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' program', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=',', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' so', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' I', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' don', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"'t\", role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' have', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' feelings', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=',', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' but', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' thank', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' you', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' for', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' asking', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='!', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' How', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' can', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' I', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' assist', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' you', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' today', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='?', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='<|eot_id|>', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"import litellm\n",
"from litellm import completion\n",
"\n",
"# Make sure to create an API_KEY with inference permissions at https://huggingface.co/settings/tokens/new?globalPermissions=inference.serverless.write&tokenType=fineGrained\n",
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"\n",
"# Call https://huggingface.co/glaiveai/glaive-coder-7b\n",
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n",
"# set api base to your deployed api endpoint from hugging face\n",
"response = litellm.completion(\n",
" model=\"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n",
" stream=True\n",
"response = completion(\n",
" model=\"huggingface/together/deepseek-ai/DeepSeek-R1\",\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"How many r's are in the word `strawberry`?\",\n",
" \n",
" }\n",
" ],\n",
" stream=True,\n",
")\n",
"\n",
"print(response)\n",
"\n",
"for chunk in response:\n",
" print(chunk)"
" print(chunk)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## With images as input\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CKXAnK55zQRl"
},
"metadata": {},
"outputs": [],
"source": []
"source": [
"from litellm import completion\n",
"\n",
"# Set your Hugging Face Token\n",
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"\n",
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"text\", \"text\": \"What's in this image?\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\",\n",
" },\n",
" },\n",
" ],\n",
" }\n",
"]\n",
"\n",
"response = completion(\n",
" model=\"huggingface/sambanova/meta-llama/Llama-3.3-70B-Instruct\",\n",
" messages=messages,\n",
")\n",
"print(response.choices[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tools - Function Calling\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from litellm import completion\n",
"\n",
"\n",
"# Set your Hugging Face Token\n",
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
" },\n",
" \"required\": [\"location\"],\n",
" },\n",
" },\n",
" }\n",
"]\n",
"messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n",
"\n",
"response = completion(\n",
" model=\"huggingface/sambanova/meta-llama/Llama-3.1-8B-Instruct\", messages=messages, tools=tools, tool_choice=\"auto\"\n",
")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hugging Face Dedicated Inference Endpoints\n",
"\n",
"Steps to use\n",
"\n",
"- Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/\n",
"- Set `api_base` to your deployed api base\n",
"- set the model to `huggingface/tgi` so that litellm knows it's a huggingface Deployed Inference Endpoint.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import litellm\n",
"\n",
"\n",
"response = litellm.completion(\n",
" model=\"huggingface/tgi\",\n",
" messages=[{\"content\": \"Hello, how are you?\", \"role\": \"user\"}],\n",
" api_base=\"https://my-endpoint.endpoints.huggingface.cloud/v1/\",\n",
")\n",
"print(response)"
]
}
],
"metadata": {
@ -251,7 +230,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
@ -264,7 +244,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.0"
}
},
"nbformat": 4,


@ -2,466 +2,392 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Huggingface
# Hugging Face
LiteLLM supports running inference across multiple services for models hosted on the Hugging Face Hub.
LiteLLM supports the following types of Hugging Face models:
- **Serverless Inference Providers** - Hugging Face offers an easy and unified access to serverless AI inference through multiple inference providers, like [Together AI](https://together.ai) and [Sambanova](https://sambanova.ai). This is the fastest way to integrate AI in your products with a maintenance-free and scalable solution. More details in the [Inference Providers documentation](https://huggingface.co/docs/inference-providers/index).
- **Dedicated Inference Endpoints** - which is a product to easily deploy models to production. Inference is run by Hugging Face in a dedicated, fully managed infrastructure on a cloud provider of your choice. You can deploy your model on Hugging Face Inference Endpoints by following [these steps](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint).
- Serverless Inference API (free) - loaded and ready to use: https://huggingface.co/models?inference=warm&pipeline_tag=text-generation
- Dedicated Inference Endpoints (paid) - manual deployment: https://ui.endpoints.huggingface.co/
- All LLMs served via Hugging Face's Inference use [Text-generation-inference](https://huggingface.co/docs/text-generation-inference).
## Supported Models
### Serverless Inference Providers
You can check available models for an inference provider by going to [huggingface.co/models](https://huggingface.co/models), clicking the "Other" filter tab, and selecting your desired provider:
![Filter models by Inference Provider](../../img/hf_filter_inference_providers.png)
For example, you can find all Fireworks supported models [here](https://huggingface.co/models?inference_provider=fireworks-ai&sort=trending).
### Dedicated Inference Endpoints
Refer to the [Inference Endpoints catalog](https://endpoints.huggingface.co/catalog) for a list of available models.
## Usage
<Tabs>
<TabItem value="serverless" label="Serverless Inference Providers">
### Authentication
With a single Hugging Face token, you can access inference through multiple providers. Your calls are routed through Hugging Face and the usage is billed directly to your Hugging Face account at the standard provider API rates.
Simply set the `HF_TOKEN` environment variable with your Hugging Face token. You can create one here: https://huggingface.co/settings/tokens.
```bash
export HF_TOKEN="hf_xxxxxx"
```
Alternatively, you can pass your Hugging Face token as a parameter:
```python
completion(..., api_key="hf_xxxxxx")
```
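For example, a minimal sketch of a per-call token (the model id is one of the examples used later on this page):
```python
from litellm import completion

# pass the Hugging Face token directly instead of relying on the HF_TOKEN env var
response = completion(
    model="huggingface/together/deepseek-ai/DeepSeek-R1",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    api_key="hf_xxxxxx",  # your Hugging Face token
)
print(response)
```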
### Getting Started
To use a Hugging Face model, specify both the provider and model you want to use in the following format:
```
huggingface/<provider>/<hf_org_or_user>/<hf_model>
```
Where `<hf_org_or_user>/<hf_model>` is the Hugging Face model ID and `<provider>` is the inference provider.
By default, if you don't specify a provider, LiteLLM will use the [HF Inference API](https://huggingface.co/docs/api-inference/en/index).
Examples:
```python
# Run DeepSeek-R1 inference through Together AI
completion(model="huggingface/together/deepseek-ai/DeepSeek-R1",...)
# Run Qwen2.5-72B-Instruct inference through Sambanova
completion(model="huggingface/sambanova/Qwen/Qwen2.5-72B-Instruct",...)
# Run Llama-3.3-70B-Instruct inference through HF Inference API
completion(model="huggingface/meta-llama/Llama-3.3-70B-Instruct",...)
```
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
You need to tell LiteLLM when you're calling Huggingface.
This is done by adding the "huggingface/" prefix to `model`, for example `completion(model="huggingface/<model_name>",...)`.
<Tabs>
<TabItem value="serverless" label="Serverless Inference API">
By default, LiteLLM will assume a Hugging Face call follows the [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api), which is fully compatible with the OpenAI Chat Completion API.
<Tabs>
<TabItem value="sdk" label="SDK">
### Basic Completion
Here's an example of chat completion using the DeepSeek-R1 model through Together AI:
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
os.environ["HF_TOKEN"] = "hf_xxxxxx"
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
response = completion(
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[{ "content": "Hello, how are you?","role": "user"}],
model="huggingface/together/deepseek-ai/DeepSeek-R1",
messages=[
{
"role": "user",
"content": "How many r's are in the word 'strawberry'?",
}
],
)
print(response)
```
### Streaming
Now, let's see what a streaming request looks like.
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
response = completion(
model="huggingface/together/deepseek-ai/DeepSeek-R1",
messages=[
{
"role": "user",
"content": "How many r's are in the word `strawberry`?",
}
],
stream=True,
)
for chunk in response:
print(chunk)
```
### Image Input
You can also pass images when the model supports it. Here is an example using [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) model through Sambanova.
```python
from litellm import completion
# Set your Hugging Face Token
os.environ["HF_TOKEN"] = "hf_xxxxxx"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
}
},
],
}
]
response = completion(
model="huggingface/sambanova/meta-llama/Llama-3.2-11B-Vision-Instruct",
messages=messages,
)
print(response.choices[0])
```
### Function Calling
You can extend the model's capabilities by giving them access to tools. Here is an example with function calling using [Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct) model through Sambanova.
```python
import os
from litellm import completion
# Set your Hugging Face Token
os.environ["HF_TOKEN"] = "hf_xxxxxx"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today?",
}
]
response = completion(
model="huggingface/sambanova/meta-llama/Llama-3.3-70B-Instruct",
messages=messages,
tools=tools,
tool_choice="auto"
)
print(response)
```
</TabItem>
<TabItem value="endpoints" label="Inference Endpoints">
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
### Basic Completion
After you have [deployed your Hugging Face Inference Endpoint](https://endpoints.huggingface.co/new) on dedicated infrastructure, you can run inference on it by providing the endpoint base URL in `api_base`, and indicating `huggingface/tgi` as the model name.
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
response = completion(
model="huggingface/tgi",
messages=[{"content": "Hello, how are you?", "role": "user"}],
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/"
)
print(response)
```
### Streaming
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
response = completion(
model="huggingface/tgi",
messages=[{"content": "Hello, how are you?", "role": "user"}],
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/",
stream=True
)
for chunk in response:
    print(chunk)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: llama-3.1-8B-instruct
litellm_params:
model: huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct
api_key: os.environ/HUGGINGFACE_API_KEY
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama-3.1-8B-instruct",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: bert-classifier
litellm_params:
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "bert-classifier",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="dedicated" label="Dedicated Inference Endpoints">
Steps to use
* Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/
* Set `api_base` to your deployed api base
* Add the `huggingface/` prefix to your model so litellm knows it's a huggingface Deployed Inference Endpoint
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
os.environ["HUGGINGFACE_API_KEY"] = ""
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
# add the 'huggingface/' prefix to the model to set huggingface as the provider
# set api base to your deployed api endpoint from hugging face
response = completion(
model="huggingface/glaiveai/glaive-coder-7b",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: glaive-coder
litellm_params:
model: huggingface/glaiveai/glaive-coder-7b
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "glaive-coder",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Streaming
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
You need to tell LiteLLM when you're calling Huggingface.
This is done by adding the "huggingface/" prefix to `model`, for example `completion(model="huggingface/<model_name>",...)`.
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
model="huggingface/facebook/blenderbot-400M-distill",
messages=messages,
api_base="https://my-endpoint.huggingface.cloud",
stream=True
)
print(response)
for chunk in response:
print(chunk)
print(chunk)
```
### Image Input
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
}
},
],
}
]
response = completion(
model="huggingface/tgi",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/""
)
print(response.choices[0])
```
### Function Calling
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
functions = [{
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get weather for"
}
},
"required": ["location"]
}
}]
response = completion(
model="huggingface/tgi",
messages=[{"content": "What's the weather like in San Francisco?", "role": "user"}],
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/",
functions=functions
)
print(response)
```
</TabItem>
</Tabs>
## LiteLLM Proxy Server with Hugging Face models
You can set up a [LiteLLM Proxy Server](https://docs.litellm.ai/#litellm-proxy-server-llm-gateway) to serve Hugging Face models through any of the supported Inference Providers. Here's how to do it:
### Step 1. Setup the config file
In this case, we are configuring a proxy to serve `DeepSeek R1` from Hugging Face, using Together AI as the backend Inference Provider.
```yaml
model_list:
- model_name: my-r1-model
litellm_params:
model: huggingface/together/deepseek-ai/DeepSeek-R1
api_key: os.environ/HF_TOKEN # ensure you have `HF_TOKEN` in your .env
```
### Step 2. Start the server
```bash
litellm --config /path/to/config.yaml
```
### Step 3. Make a request to the server
<Tabs>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
"model": "my-r1-model",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}'
```
</TabItem>
<TabItem value="python" label="python">
```python
# pip install openai
from openai import OpenAI
client = OpenAI(
base_url="http://0.0.0.0:4000",
api_key="anything",
)
response = client.chat.completions.create(
model="my-r1-model",
messages=[
{"role": "user", "content": "Hello, how are you?"}
]
)
print(response)
```
</TabItem>
</Tabs>
## Embedding
LiteLLM supports Hugging Face's [text-embedding-inference](https://github.com/huggingface/text-embeddings-inference) format.
LiteLLM supports Hugging Face's [text-embedding-inference](https://github.com/huggingface/text-embeddings-inference) models as well.
```python
from litellm import embedding
import os
os.environ['HUGGINGFACE_API_KEY'] = ""
os.environ['HF_TOKEN'] = "hf_xxxxxx"
response = embedding(
model='huggingface/microsoft/codebert-base',
input=["good morning from litellm"]
)
```
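If you are running your own [text-embeddings-inference](https://github.com/huggingface/text-embeddings-inference) deployment, here is a sketch of the same call pointed at a dedicated endpoint (the model id and endpoint URL below are placeholders, and this assumes `api_base` is honored for embeddings the same way it is for chat):
```python
import os
from litellm import embedding

os.environ["HF_TOKEN"] = "hf_xxxxxx"

# placeholder: a dedicated text-embeddings-inference endpoint
response = embedding(
    model="huggingface/BAAI/bge-large-en-v1.5",
    input=["good morning from litellm"],
    api_base="https://my-tei-endpoint.endpoints.huggingface.cloud",
)
print(response)
```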
## Advanced
### Setting API KEYS + API BASE
If required, you can set the API key and API base via environment variables. [Code for how it's sent](https://github.com/BerriAI/litellm/blob/0100ab2382a0e720c7978fbf662cc6e6920e7e03/litellm/llms/huggingface_restapi.py#L25)
```python
import os
os.environ["HUGGINGFACE_API_KEY"] = ""
os.environ["HUGGINGFACE_API_BASE"] = ""
```
### Viewing Log probs
#### Using `decoder_input_details` - OpenAI `echo`
The `echo` param is supported by OpenAI Completions - Use `litellm.text_completion()` for this
```python
from litellm import text_completion
response = text_completion(
model="huggingface/bigcode/starcoder",
prompt="good morning",
max_tokens=10, logprobs=10,
echo=True
)
```
#### Output
```json
{
"id": "chatcmpl-3fc71792-c442-4ba1-a611-19dd0ac371ad",
"object": "text_completion",
"created": 1698801125.936519,
"model": "bigcode/starcoder",
"choices": [
{
"text": ", I'm going to make you a sand",
"index": 0,
"logprobs": {
"tokens": [
"good",
" morning",
",",
" I",
"'m",
" going",
" to",
" make",
" you",
" a",
" s",
"and"
],
"token_logprobs": [
"None",
-14.96875,
-2.2285156,
-2.734375,
-2.0957031,
-2.0917969,
-0.09429932,
-3.1132812,
-1.3203125,
-1.2304688,
-1.6201172,
-0.010292053
]
},
"finish_reason": "length"
}
],
"usage": {
"completion_tokens": 9,
"prompt_tokens": 2,
"total_tokens": 11
}
}
```
### Models with Prompt Formatting
For models with special prompt templates (e.g. Llama2), we format the prompt to fit their template.
#### Models with natively Supported Prompt Templates
| Model Name | Works for Models | Function Call | Required OS Variables |
| ------------------------------------ | ---------------------------------- | ----------------------------------------------------------------------------------------------------------------------- | ----------------------------------- |
| mistralai/Mistral-7B-Instruct-v0.1 | mistralai/Mistral-7B-Instruct-v0.1 | `completion(model='huggingface/mistralai/Mistral-7B-Instruct-v0.1', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| meta-llama/Llama-2-7b-chat | All meta-llama llama2 chat models | `completion(model='huggingface/meta-llama/Llama-2-7b', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| tiiuae/falcon-7b-instruct | All falcon instruct models | `completion(model='huggingface/tiiuae/falcon-7b-instruct', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| mosaicml/mpt-7b-chat | All mpt chat models | `completion(model='huggingface/mosaicml/mpt-7b-chat', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| codellama/CodeLlama-34b-Instruct-hf | All codellama instruct models | `completion(model='huggingface/codellama/CodeLlama-34b-Instruct-hf', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| WizardLM/WizardCoder-Python-34B-V1.0 | All wizardcoder models | `completion(model='huggingface/WizardLM/WizardCoder-Python-34B-V1.0', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| Phind/Phind-CodeLlama-34B-v2 | All phind-codellama models | `completion(model='huggingface/Phind/Phind-CodeLlama-34B-v2', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
**What if we don't support a model you need?**
You can also specify your own custom prompt formatting, in case we don't have your model covered yet.
**Does this mean you have to specify a prompt for all models?**
No. By default we'll concatenate your message content to make a prompt.
**Default Prompt Template**
```python
def default_pt(messages):
return " ".join(message["content"] for message in messages)
```
[Code for how prompt formats work in LiteLLM](https://github.com/BerriAI/litellm/blob/main/litellm/llms/prompt_templates/factory.py)
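For instance, a quick illustration of what the default template above produces (it simply joins the message contents with spaces):
```python
# quick illustration of the default prompt template defined above
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
print(default_pt(messages))  # -> "Hello Hi there"
```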
#### Custom prompt templates
```python
import litellm
# Create your own custom prompt template works
litellm.register_prompt_template(
model="togethercomputer/LLaMA-2-7B-32K",
roles={
"system": {
"pre_message": "[INST] <<SYS>>\n",
"post_message": "\n<</SYS>>\n [/INST]\n"
},
"user": {
"pre_message": "[INST] ",
"post_message": " [/INST]\n"
},
"assistant": {
"post_message": "\n"
}
}
)
def test_huggingface_custom_model():
model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud")
print(response['choices'][0]['message']['content'])
return response
test_huggingface_custom_model()
```
[Implementation Code](https://github.com/BerriAI/litellm/blob/c0b3da2c14c791a0b755f0b1e5a9ef065951ecbf/litellm/llms/huggingface_restapi.py#L52)
### Deploying a model on huggingface
You can use any chat/text model from Hugging Face with the following steps:
- Copy your model id/url from Huggingface Inference Endpoints
- [ ] Go to https://ui.endpoints.huggingface.co/
- [ ] Copy the url of the specific model you'd like to use
<Image img={require('../../img/hf_inference_endpoint.png')} alt="HF_Dashboard" style={{ maxWidth: '50%', height: 'auto' }}/>
- Set it as your model name
- Set your HUGGINGFACE_API_KEY as an environment variable
Need help deploying a model on huggingface? [Check out this guide.](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint)
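A minimal sketch of the pattern above, reusing the `api_base` style shown earlier on this page (the model id and endpoint URL are placeholders for your own deployment):
```python
import os
from litellm import completion

os.environ["HUGGINGFACE_API_KEY"] = "hf_xxxxxx"

# placeholders: your deployed model id and the endpoint URL copied from https://ui.endpoints.huggingface.co/
response = completion(
    model="huggingface/my-org/my-deployed-model",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```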
# output
Same as the OpenAI format, but also includes logprobs. [See the code](https://github.com/BerriAI/litellm/blob/b4b2dbf005142e0a483d46a07a88a19814899403/litellm/llms/huggingface_restapi.py#L115)
```json
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "\ud83d\ude31\n\nComment: @SarahSzabo I'm",
"role": "assistant",
"logprobs": -22.697942825499993
}
}
],
"created": 1693436637.38206,
"model": "https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud",
"usage": {
"prompt_tokens": 14,
"completion_tokens": 11,
"total_tokens": 25
}
}
```
# FAQ
**Does this support stop sequences?**
**How does billing work with Hugging Face Inference Providers?**
Yes, we support stop sequences - and you can pass as many as allowed by Hugging Face (or any provider!)
> Billing is centralized on your Hugging Face account, no matter which providers you are using. You are billed the standard provider API rates with no additional markup - Hugging Face simply passes through the provider costs. Note that [Hugging Face PRO](https://huggingface.co/subscribe/pro) users get $2 worth of Inference credits every month that can be used across providers.
**How do you deal with repetition penalty?**
**Do I need to create an account for each Inference Provider?**
We map the presence penalty parameter in openai to the repetition penalty parameter on Hugging Face. [See code](https://github.com/BerriAI/litellm/blob/b4b2dbf005142e0a483d46a07a88a19814899403/litellm/utils.py#L757).
> No, you don't need to create separate accounts. All requests are routed through Hugging Face, so you only need your HF token. This allows you to easily benchmark different providers and choose the one that best fits your needs.
We welcome any suggestions for improving our Hugging Face integration - Create an [issue](https://github.com/BerriAI/litellm/issues/new/choose)/[Join the Discord](https://discord.com/invite/wuPM9dRgDw)!
**Will more inference providers be supported by Hugging Face in the future?**
> Yes! New inference providers (and models) are being added gradually.
We welcome any suggestions for improving our Hugging Face integration - Create an [issue](https://github.com/BerriAI/litellm/issues/new/choose)/[Join the Discord](https://discord.com/invite/wuPM9dRgDw)!


@ -406,6 +406,7 @@ router_settings:
| HELICONE_API_KEY | API key for Helicone service
| HOSTNAME | Hostname for the server, this will be [emitted to `datadog` logs](https://docs.litellm.ai/docs/proxy/logging#datadog)
| HUGGINGFACE_API_BASE | Base URL for Hugging Face API
| HUGGINGFACE_API_KEY | API key for Hugging Face API
| IAM_TOKEN_DB_AUTH | IAM token for database authentication
| JSON_LOGS | Enable JSON formatted logging
| JWT_AUDIENCE | Expected audience for JWT tokens

Binary file not shown (image, 120 KiB added).


@ -800,9 +800,8 @@ from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
from .llms.galadriel.chat.transformation import GaladrielChatConfig
from .llms.github.chat.transformation import GithubChatConfig
from .llms.empower.chat.transformation import EmpowerChatConfig
from .llms.huggingface.chat.transformation import (
HuggingfaceChatConfig as HuggingfaceConfig,
)
from .llms.huggingface.chat.transformation import HuggingFaceChatConfig
from .llms.huggingface.embedding.transformation import HuggingFaceEmbeddingConfig
from .llms.oobabooga.chat.transformation import OobaboogaConfig
from .llms.maritalk import MaritalkConfig
from .llms.openrouter.chat.transformation import OpenrouterConfig


@ -120,7 +120,7 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "replicate":
return litellm.ReplicateConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "huggingface":
return litellm.HuggingfaceConfig().get_supported_openai_params(model=model)
return litellm.HuggingFaceChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "jina_ai":
if request_type == "embeddings":
return litellm.JinaAIEmbeddingConfig().get_supported_openai_params()


@ -355,15 +355,6 @@ class LiteLLMResponseObjectHandler:
Only supported for HF TGI models
"""
transformed_logprobs: Optional[TextCompletionLogprobs] = None
if custom_llm_provider == "huggingface":
# only supported for TGI models
try:
raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.huggingface._transform_logprobs(
hf_response=raw_response
)
except Exception as e:
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
return transformed_logprobs


@ -214,10 +214,7 @@ class CustomStreamWrapper:
Output parse <s> / </s> special tokens for sagemaker + hf streaming.
"""
hold = False
if (
self.custom_llm_provider != "huggingface"
and self.custom_llm_provider != "sagemaker"
):
if self.custom_llm_provider != "sagemaker":
return hold, chunk
if finish_reason:
@ -290,49 +287,6 @@ class CustomStreamWrapper:
except Exception as e:
raise e
def handle_huggingface_chunk(self, chunk):
try:
if not isinstance(chunk, str):
chunk = chunk.decode(
"utf-8"
) # DO NOT REMOVE this: This is required for HF inference API + Streaming
text = ""
is_finished = False
finish_reason = ""
print_verbose(f"chunk: {chunk}")
if chunk.startswith("data:"):
data_json = json.loads(chunk[5:])
print_verbose(f"data json: {data_json}")
if "token" in data_json and "text" in data_json["token"]:
text = data_json["token"]["text"]
if data_json.get("details", False) and data_json["details"].get(
"finish_reason", False
):
is_finished = True
finish_reason = data_json["details"]["finish_reason"]
elif data_json.get(
"generated_text", False
): # if full generated text exists, then stream is complete
text = "" # don't return the final bos token
is_finished = True
finish_reason = "stop"
elif data_json.get("error", False):
raise Exception(data_json.get("error"))
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif "error" in chunk:
raise ValueError(chunk)
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except Exception as e:
raise e
def handle_ai21_chunk(self, chunk): # fake streaming
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
@ -1049,11 +1003,6 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
response_obj = self.handle_huggingface_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
response_obj = self.handle_predibase_chunk(chunk)
completion_obj["content"] = response_obj["text"]


@ -1,769 +0,0 @@
## Uses the huggingface text generation inference API
import json
import os
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
get_args,
)
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.huggingface.chat.transformation import (
HuggingfaceChatConfig as HuggingfaceConfig,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import EmbeddingResponse
from litellm.types.utils import Logprobs as TextCompletionLogprobs
from litellm.types.utils import ModelResponse
from ...base import BaseLLM
from ..common_utils import HuggingfaceError
hf_chat_config = HuggingfaceConfig()
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]
def get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = HTTPHandler(concurrent_limit=1)
model_info = http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def async_get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
model_info = await http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
try:
response = await client.post(
api_base, headers=headers, data=data, stream=True, timeout=timeout
)
except httpx.HTTPStatusError as e:
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise HuggingfaceError(
status_code=e.response.status_code,
message=str(await e.response.aread()),
headers=cast(dict, error_headers) if error_headers else None,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise HuggingfaceError(status_code=500, message=str(e))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return response.aiter_lines(), response.headers
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
def completion( # noqa: PLR0915
self,
model: str,
messages: list,
api_base: Optional[str],
model_response: ModelResponse,
print_verbose: Callable,
timeout: float,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params: dict,
custom_prompt_dict={},
acompletion: bool = False,
logger_fn=None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers: dict = {},
):
super().completion()
exception_mapping_worked = False
try:
task, model = hf_chat_config.get_hf_task_for_model(model)
litellm_params["task"] = task
headers = hf_chat_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
completion_url = hf_chat_config.get_api_base(api_base=api_base, model=model)
data = hf_chat_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=data,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages) # type: ignore
else:
### ASYNC COMPLETION
return self.acompletion(
api_base=completion_url,
data=data,
headers=headers,
model_response=model_response,
encoding=encoding,
model=model,
optional_params=optional_params,
timeout=timeout,
litellm_params=litellm_params,
logging_obj=logging_obj,
api_key=api_key,
messages=messages,
client=(
client
if client is not None
and isinstance(client, AsyncHTTPHandler)
else None
),
)
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] is True:
response = client.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
### SYNC COMPLETION
else:
response = client.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return hf_chat_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
json_mode=None,
litellm_params=litellm_params,
)
except httpx.HTTPStatusError as e:
raise HuggingfaceError(
status_code=e.response.status_code,
message=e.response.text,
headers=e.response.headers,
)
except HuggingfaceError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(
self,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
encoding: Any,
model: str,
optional_params: dict,
litellm_params: dict,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: str,
messages: List[AllMessageValues],
client: Optional[AsyncHTTPHandler] = None,
):
response: Optional[httpx.Response] = None
try:
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE
)
### ASYNC COMPLETION
http_response = await client.post(
url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
)
response = http_response
return hf_chat_config.transform_response(
model=model,
raw_response=http_response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
json_mode=None,
litellm_params=litellm_params,
)
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif isinstance(e, HuggingfaceError):
raise e
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
headers=response.headers,
)
else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
messages: List[AllMessageValues],
model: str,
timeout: float,
client: Optional[AsyncHTTPHandler] = None,
):
completion_stream, _ = await make_call(
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=False,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
return streamwrapper
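# Consumption sketch for the wrapper returned above ("handler" is an assumed Huggingface
# instance and the argument list is elided; shown for illustration only):
#
# wrapper = await handler.async_streaming(...)
# async for chunk in wrapper:
#     print(chunk.choices[0].delta.content or "", end="")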
def _transform_input_on_pipeline_tag(
self, input: List, pipeline_tag: Optional[str]
) -> dict:
if pipeline_tag is None:
return {"inputs": input}
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="sentence-similarity requires 2+ sentences",
)
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
elif pipeline_tag == "rerank":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="reranker requires 2+ sentences",
)
return {"inputs": {"query": input[0], "texts": input[1:]}}
return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input(
self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
special_options_keys = HuggingfaceConfig().get_special_options_params()
special_parameters_keys = [
"min_length",
"max_length",
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"max_time",
]
for k, v in optional_params.items():
if k in special_options_keys:
data.setdefault("options", {})
data["options"][k] = v
elif k in special_parameters_keys:
data.setdefault("parameters", {})
data["parameters"][k] = v
else:
data[k] = v
return data
def _transform_input(
self,
input: List,
model: str,
call_type: Literal["sync", "async"],
optional_params: dict,
embed_url: str,
) -> dict:
data: Dict = {}
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input}
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
) # type: ignore
data = self._transform_input_on_pipeline_tag(
input=input, pipeline_tag=hf_task
)
if len(optional_params.keys()) > 0:
data = self._process_optional_params(
data=data, optional_params=optional_params
)
return data
def _process_embedding_response(
self,
embeddings: dict,
model_response: EmbeddingResponse,
model: str,
input: List,
encoding: Any,
) -> EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=input_tokens,
completion_tokens=input_tokens,
total_tokens=input_tokens,
prompt_tokens_details=None,
completion_tokens_details=None,
),
)
return model_response
async def aembedding(
self,
model: str,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def embedding(
self,
model: str,
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
) -> EmbeddingResponse:
super().embedding()
headers = hf_chat_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
optional_params=optional_params,
messages=[],
)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
## ROUTING ##
if aembedding is True:
return self.aembedding(
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
headers=headers,
api_base=embed_url, # type: ignore
api_key=api_key,
client=client if isinstance(client, AsyncHTTPHandler) else None,
model=model,
optional_params=optional_params,
encoding=encoding,
)
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=embed_url,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def _transform_logprobs(
self, hf_response: Optional[List]
) -> Optional[TextCompletionLogprobs]:
"""
Transform Hugging Face logprobs to OpenAI.Completion() format
"""
if hf_response is None:
return None
# Initialize an empty list for the transformed logprobs
_logprob: TextCompletionLogprobs = TextCompletionLogprobs(
text_offset=[],
token_logprobs=[],
tokens=[],
top_logprobs=[],
)
# For each Hugging Face response, transform the logprobs
for response in hf_response:
# Extract the relevant information from the response
response_details = response["details"]
top_tokens = response_details.get("top_tokens", {})
for i, token in enumerate(response_details["prefill"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
# Add the token information to the 'token_info' list
cast(List[str], _logprob.tokens).append(token_text)
cast(List[float], _logprob.token_logprobs).append(token_logprob)
# stub this to work with llm eval harness
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
top_alt_tokens
)
# For each element in the 'tokens' list, extract the relevant information
for i, token in enumerate(response_details["tokens"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
top_alt_tokens = {}
temp_top_logprobs = []
if top_tokens != {}:
temp_top_logprobs = top_tokens[i]
# top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
for elem in temp_top_logprobs:
text = elem["text"]
logprob = elem["logprob"]
top_alt_tokens[text] = logprob
# Add the token information to the 'token_info' list
cast(List[str], _logprob.tokens).append(token_text)
cast(List[float], _logprob.token_logprobs).append(token_logprob)
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
top_alt_tokens
)
# Add the text offset of the token
# This is computed as the sum of the lengths of all previous tokens
cast(List[int], _logprob.text_offset).append(
sum(len(t["text"]) for t in response_details["tokens"][:i])
)
return _logprob
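# Shape of the TGI "details" payload consumed above (values are illustrative):
#
# hf_response = [{
#     "details": {
#         "prefill": [{"text": "Hello", "logprob": -0.5}],
#         "tokens": [{"text": " world", "logprob": -0.2}, {"text": "!", "logprob": -0.1}],
#         "top_tokens": {},
#     }
# }]
#
# _transform_logprobs(hf_response) then returns a TextCompletionLogprobs with
# tokens=["Hello", " world", "!"], token_logprobs=[-0.5, -0.2, -0.1],
# stubbed top_logprobs for the prefill token, and text_offset built from cumulative token lengths.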

@@ -1,27 +1,10 @@
import json
import logging
import os
import time
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks, output_parser
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -30,176 +13,98 @@ if TYPE_CHECKING:
else:
LoggingClass = Any
from litellm.llms.base_llm.chat.transformation import BaseLLMException
tgi_models_cache = None
conv_models_cache = None
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
class HuggingfaceChatConfig(BaseConfig):
logger = logging.getLogger(__name__)
BASE_URL = "https://router.huggingface.co"
class HuggingFaceChatConfig(OpenAIGPTConfig):
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
"""
hf_task: Optional[
hf_tasks
] = None # litellm-specific param, used to know the api spec to use when calling huggingface api
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_n_tokens: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
def validate_environment(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_special_options_params(self):
return ["use_cache", "wait_for_model"]
def get_supported_openai_params(self, model: str):
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
headers: dict,
model: str,
drop_params: bool,
) -> Dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params[
"do_sample"
] = True # Need to sample if you want best of for hf inference endpoints
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens" or param == "max_completion_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers["Authorization"] = f"Bearer {api_key}"
return optional_params
headers = {**headers, **default_headers}
def get_hf_api_key(self) -> Optional[str]:
return get_secret_str("HUGGINGFACE_API_KEY")
return headers
def read_tgi_conv_models(self):
try:
global tgi_models_cache, conv_models_cache
# Check if the cache is already populated
# so we don't keep on reading txt file if there are 1k requests
if (tgi_models_cache is not None) and (conv_models_cache is not None):
return tgi_models_cache, conv_models_cache
# If not, read the file and populate the cache
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
script_directory = os.path.dirname(script_directory)
# Construct the file path relative to the script's directory
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
"""
Get the API base for the Huggingface API.
# Cache the set for future use
tgi_models_cache = tgi_models
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if model.startswith(("http://", "https://")):
base_url = model
elif base_url is None:
base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
return base_url
# If not, read the file and populate the cache
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
conv_models_cache = conv_models
return tgi_models, conv_models
except Exception:
return set(), set()
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
split_model = model.split("/", 1)
return split_model[0], split_model[1] # type: ignore
tgi_models, conversational_models = self.read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference", model
elif model in conversational_models:
return "conversational", model
elif "roneneldan/TinyStories" in model:
return "text-generation", model
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for the API call.
For provider-specific routing through huggingface
"""
# 1. Check if api_base is provided
if api_base is not None:
complete_url = api_base
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
complete_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
elif model.startswith(("http://", "https://")):
complete_url = model
# 4. Default construction with provider
else:
return "text-generation-inference", model # default to tgi
# Parse provider and model
first_part, remaining = model.split("/", 1)
if "/" in remaining:
provider = first_part
else:
provider = "hf-inference"
if provider == "hf-inference":
route = f"{provider}/models/{model}/v1/chat/completions"
elif provider == "novita":
route = f"{provider}/chat/completions"
else:
route = f"{provider}/v1/chat/completions"
complete_url = f"{BASE_URL}/{route}"
# Ensure URL doesn't end with a slash
complete_url = complete_url.rstrip("/")
return complete_url
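# Routing examples for the branches above ("example-org/example-model" is a placeholder id):
#   "together/example-org/example-model" -> https://router.huggingface.co/together/v1/chat/completions
#   "novita/example-org/example-model"   -> https://router.huggingface.co/novita/chat/completions
#   "example-org/example-model"          -> https://router.huggingface.co/hf-inference/models/example-org/example-model/v1/chat/completions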
def transform_request(
self,
@@ -209,381 +114,28 @@ class HuggingfaceChatConfig(BaseConfig):
litellm_params: dict,
headers: dict,
) -> dict:
task = litellm_params.get("task", None)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
## Load Config
config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
#### HANDLE SPECIAL PARAMS
special_params = self.get_special_options_params()
special_params_dict = {}
# Create a list of keys to pop after iteration
keys_to_pop = []
for k, v in optional_params.items():
if k in special_params:
special_params_dict[k] = v
keys_to_pop.append(k)
# Pop the keys from the dictionary after iteration
for k in keys_to_pop:
optional_params.pop(k)
if task == "conversational":
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
past_user_inputs = []
generated_responses = []
text = ""
for message in messages:
if message["role"] == "user":
if text != "":
past_user_inputs.append(text)
text = convert_content_list_to_str(message)
elif message["role"] == "assistant" or message["role"] == "system":
generated_responses.append(convert_content_list_to_str(message))
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params,
}
elif task == "text-generation-inference":
# always send "details" and "return_full_text" as params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True
if "stream" in optional_params
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] is True # type: ignore
else False
),
}
if "max_retries" in optional_params:
logger.warning("`max_retries` is not supported. It will be ignored.")
optional_params.pop("max_retries", None)
first_part, remaining = model.split("/", 1)
if "/" in remaining:
provider = first_part
model_id = remaining
else:
# Non TGI and Conversational llms
# We need this branch, it removes 'details' and 'return_full_text' from params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params and optional_params["stream"] is True
else False
)
### RE-ADD SPECIAL PARAMS
if len(special_params_dict.keys()) > 0:
data.update({"options": special_params_dict})
return data
def get_api_base(self, api_base: Optional[str], model: str) -> str:
"""
Get the API base for the Huggingface API.
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if "https" in model:
completion_url = model
elif api_base is not None:
completion_url = api_base
elif "HF_API_BASE" in os.environ:
completion_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
return completion_url
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = {**headers, **default_headers}
return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingfaceError(
status_code=status_code, message=error_message, headers=headers
)
def _convert_streamed_response_to_complete_response(
self,
response: httpx.Response,
logging_obj: LoggingClass,
model: str,
data: dict,
api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
## LOGGING
logging_obj.post_call(
input=data,
api_key=api_key,
original_response=completion_response,
additional_args={"complete_input_dict": data},
)
return completion_response
def convert_to_model_response_object( # noqa: PLR0915
self,
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
model_response: ModelResponse,
task: Optional[hf_tasks],
optional_params: dict,
encoding: Any,
messages: List[AllMessageValues],
model: str,
):
if task is None:
task = "text-generation-inference" # default to tgi
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response.choices[0].message.content = completion_response[ # type: ignore
"generated_text"
]
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]
):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
headers=None,
)
if len(completion_response[0]["generated_text"]) > 0:
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices.extend(choices_list)
elif task == "text-classification":
model_response.choices[0].message.content = json.dumps( # type: ignore
completion_response
provider = "hf-inference"
model_id = model
provider_mapping = _fetch_inference_provider_mapping(model_id)
if provider not in provider_mapping:
raise HuggingFaceError(
message=f"Model {model_id} is not supported for provider {provider}",
status_code=404,
headers={},
)
else:
if (
isinstance(completion_response, list)
and len(completion_response[0]["generated_text"]) > 0
):
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = token_counter(model=model, messages=messages)
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
completion_tokens = 0
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
model_response._hidden_params["original_response"] = completion_response
return model_response
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
task = litellm_params.get("task", None)
is_streamed = False
if (
raw_response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
completion_response = self._convert_streamed_response_to_complete_response(
response=raw_response,
logging_obj=logging_obj,
model=model,
data=request_data,
api_key=api_key,
provider_mapping = provider_mapping[provider]
if provider_mapping["status"] == "staging":
logger.warning(
f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
)
else:
## LOGGING
logging_obj.post_call(
input=request_data,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
if isinstance(completion_response, dict):
completion_response = [completion_response]
except Exception:
raise HuggingfaceError(
message=f"Original Response received: {raw_response.text}",
status_code=raw_response.status_code,
)
if isinstance(completion_response, dict) and "error" in completion_response:
raise HuggingfaceError(
message=completion_response["error"], # type: ignore
status_code=raw_response.status_code,
)
return self.convert_to_model_response_object(
completion_response=completion_response,
model_response=model_response,
task=task if task is not None and task in hf_task_list else None,
optional_params=optional_params,
encoding=encoding,
messages=messages,
model=model,
)
mapped_model = provider_mapping["providerId"]
messages = self._transform_messages(messages=messages, model=mapped_model)
return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
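# Sketch of the transformed payload, assuming the provider mapping resolved above returned
# {"providerId": "example-provider/model-id", "status": "live"} (hypothetical values):
#
# {
#     "model": "example-provider/model-id",
#     "messages": [{"role": "user", "content": "Hello"}],
#     "temperature": 0.7,
# }
#
# i.e. the Hub model id is swapped for the provider's own id before the OpenAI-compatible
# request is sent to the router URL built by get_complete_url.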

@@ -1,18 +1,30 @@
import os
from functools import lru_cache
from typing import Literal, Optional, Union
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
HF_HUB_URL = "https://huggingface.co"
class HuggingfaceError(BaseLLMException):
class HuggingFaceError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]] = None,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[Union[httpx.Headers, dict]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)
super().__init__(
status_code=status_code,
message=message,
request=request,
response=response,
headers=headers,
)
hf_tasks = Literal[
@@ -43,3 +55,48 @@ def output_parser(generated_text: str):
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
@lru_cache(maxsize=128)
def _fetch_inference_provider_mapping(model: str) -> dict:
"""
Fetch provider mappings for a model from the Hugging Face Hub.
Args:
model: The model identifier (e.g., 'meta-llama/Llama-2-7b')
Returns:
dict: The inference provider mapping for the model
Raises:
ValueError: If no provider mapping is found
HuggingFaceError: If the API request fails
"""
headers = {"Accept": "application/json"}
if os.getenv("HUGGINGFACE_API_KEY"):
headers["Authorization"] = f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"
path = f"{HF_HUB_URL}/api/models/{model}"
params = {"expand": ["inferenceProviderMapping"]}
try:
response = httpx.get(path, headers=headers, params=params)
response.raise_for_status()
provider_mapping = response.json().get("inferenceProviderMapping")
if provider_mapping is None:
raise ValueError(f"No provider mapping found for model {model}")
return provider_mapping
except httpx.HTTPError as e:
if hasattr(e, "response"):
status_code = getattr(e.response, "status_code", 500)
headers = getattr(e.response, "headers", {})
else:
status_code = 500
headers = {}
raise HuggingFaceError(
message=f"Failed to fetch provider mapping: {str(e)}",
status_code=status_code,
headers=headers,
)
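# Usage sketch (model id taken from the docstring example above; the mapping shape mirrors
# what transform_request expects, i.e. {provider: {"providerId": ..., "status": ...}}).
# Note this raises ValueError if the Hub returns no inferenceProviderMapping for the model.
if __name__ == "__main__":
    mapping = _fetch_inference_provider_mapping("meta-llama/Llama-2-7b")
    for provider, entry in mapping.items():
        print(provider, entry.get("providerId"), entry.get("status"))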

@@ -0,0 +1,421 @@
import json
import os
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Union,
get_args,
)
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse
from ...base import BaseLLM
from ..common_utils import HuggingFaceError
from .transformation import HuggingFaceEmbeddingConfig
config = HuggingFaceEmbeddingConfig()
HF_HUB_URL = "https://huggingface.co"
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = HTTPHandler(concurrent_limit=1)
model_info = http_client.get(url=f"{api_base}/api/models/{model}")
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
model_info = await http_client.get(url=f"{api_base}/api/models/{model}")
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
class HuggingFaceEmbedding(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
def _transform_input_on_pipeline_tag(
self, input: List, pipeline_tag: Optional[str]
) -> dict:
if pipeline_tag is None:
return {"inputs": input}
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
if len(input) < 2:
raise HuggingFaceError(
status_code=400,
message="sentence-similarity requires 2+ sentences",
)
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
elif pipeline_tag == "rerank":
if len(input) < 2:
raise HuggingFaceError(
status_code=400,
message="reranker requires 2+ sentences",
)
return {"inputs": {"query": input[0], "texts": input[1:]}}
return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input(
self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
special_options_keys = config.get_special_options_params()
special_parameters_keys = [
"min_length",
"max_length",
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"max_time",
]
for k, v in optional_params.items():
if k in special_options_keys:
data.setdefault("options", {})
data["options"][k] = v
elif k in special_parameters_keys:
data.setdefault("parameters", {})
data["parameters"][k] = v
else:
data[k] = v
return data
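# Example split of optional params (keys follow the lists above, values are illustrative):
#   {"wait_for_model": True, "max_length": 64, "normalize": True}
#   -> data["options"]["wait_for_model"] = True   # special "options" key
#   -> data["parameters"]["max_length"] = 64      # special "parameters" key
#   -> data["normalize"] = True                   # everything else stays top-level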
def _transform_input(
self,
input: List,
model: str,
call_type: Literal["sync", "async"],
optional_params: dict,
embed_url: str,
) -> dict:
data: Dict = {}
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingFaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input}
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
) # type: ignore
data = self._transform_input_on_pipeline_tag(
input=input, pipeline_tag=hf_task
)
if len(optional_params.keys()) > 0:
data = self._process_optional_params(
data=data, optional_params=optional_params
)
return data
def _process_embedding_response(
self,
embeddings: dict,
model_response: EmbeddingResponse,
model: str,
input: List,
encoding: Any,
) -> EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=input_tokens,
completion_tokens=input_tokens,
total_tokens=input_tokens,
prompt_tokens_details=None,
completion_tokens_details=None,
),
)
return model_response
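# Response shapes handled above (values illustrative):
#   {"similarities": [[0, 0.91], [1, 0.12]]}  -> one entry per (index, score) pair
#   [0.1, 0.2]                                -> one float per input
#   [[0.1, 0.2, ...], [0.3, 0.4, ...]]        -> one embedding vector per input
#   [[[0.1, ...], ...], ...]                  -> token-level output; only embedding[0][0] is kept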
async def aembedding(
self,
model: str,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingFaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def embedding(
self,
model: str,
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
) -> EmbeddingResponse:
super().embedding()
headers = config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
optional_params=optional_params,
messages=[],
)
task_type = optional_params.pop("input_type", None)
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
## ROUTING ##
if aembedding is True:
return self.aembedding(
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
headers=headers,
api_base=embed_url, # type: ignore
api_key=api_key,
client=client if isinstance(client, AsyncHTTPHandler) else None,
model=model,
optional_params=optional_params,
encoding=encoding,
)
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=embed_url,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingFaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
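# Minimal usage sketch for the embedding path above. The model id is an example only; set
# HUGGINGFACE_API_KEY (or pass api_key=...) for authenticated calls. By default this routes
# to the hf-inference pipeline URL constructed above.
if __name__ == "__main__":
    example_response = litellm.embedding(
        model="huggingface/BAAI/bge-small-en-v1.5",  # example embedding model on the Hub
        input=["good morning from litellm", "hello world"],
    )
    print(example_response.data[0]["embedding"][:5])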

@@ -0,0 +1,589 @@
import json
import os
import time
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import HuggingFaceError, hf_task_list, hf_tasks, output_parser
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
tgi_models_cache = None
conv_models_cache = None
class HuggingFaceEmbeddingConfig(BaseConfig):
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[
hf_tasks
] = None # litellm-specific param, used to know the api spec to use when calling huggingface api
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_n_tokens: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_special_options_params(self):
return ["use_cache", "wait_for_model"]
def get_supported_openai_params(self, model: str):
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params[
"do_sample"
] = True # Need to sample if you want best of for hf inference endpoints
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens" or param == "max_completion_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
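# Example mapping produced by the branches above (illustrative values):
#   {"temperature": 0, "n": 2, "max_tokens": 0, "echo": True}
#   -> {"temperature": 0.01, "best_of": 2, "do_sample": True,
#       "max_new_tokens": 1, "decoder_input_details": True}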
def get_hf_api_key(self) -> Optional[str]:
return get_secret_str("HUGGINGFACE_API_KEY")
def read_tgi_conv_models(self):
try:
global tgi_models_cache, conv_models_cache
# Check if the cache is already populated
# so we don't keep on reading txt file if there are 1k requests
if (tgi_models_cache is not None) and (conv_models_cache is not None):
return tgi_models_cache, conv_models_cache
# If not, read the file and populate the cache
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
script_directory = os.path.dirname(script_directory)
# Construct the file path relative to the script's directory
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
# Cache the set for future use
tgi_models_cache = tgi_models
# If not, read the file and populate the cache
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
conv_models_cache = conv_models
return tgi_models, conv_models
except Exception:
return set(), set()
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
split_model = model.split("/", 1)
return split_model[0], split_model[1] # type: ignore
tgi_models, conversational_models = self.read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference", model
elif model in conversational_models:
return "conversational", model
elif "roneneldan/TinyStories" in model:
return "text-generation", model
else:
return "text-generation-inference", model # default to tgi
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
task = litellm_params.get("task", None)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
## Load Config
config = litellm.HuggingFaceEmbeddingConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
#### HANDLE SPECIAL PARAMS
special_params = self.get_special_options_params()
special_params_dict = {}
# Create a list of keys to pop after iteration
keys_to_pop = []
for k, v in optional_params.items():
if k in special_params:
special_params_dict[k] = v
keys_to_pop.append(k)
# Pop the keys from the dictionary after iteration
for k in keys_to_pop:
optional_params.pop(k)
if task == "conversational":
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
past_user_inputs = []
generated_responses = []
text = ""
for message in messages:
if message["role"] == "user":
if text != "":
past_user_inputs.append(text)
text = convert_content_list_to_str(message)
elif message["role"] == "assistant" or message["role"] == "system":
generated_responses.append(convert_content_list_to_str(message))
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params,
}
elif task == "text-generation-inference":
# always send "details" and "return_full_text" as params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True
if "stream" in optional_params
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] is True # type: ignore
else False
),
}
else:
# Non TGI and Conversational llms
# We need this branch, it removes 'details' and 'return_full_text' from params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params and optional_params["stream"] is True
else False
)
### RE-ADD SPECIAL PARAMS
if len(special_params_dict.keys()) > 0:
data.update({"options": special_params_dict})
return data
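# Example payloads produced above (contents illustrative):
#   task == "conversational":
#       {"inputs": {"text": "latest user turn", "past_user_inputs": [...],
#                   "generated_responses": [...]}, "parameters": {...}}
#   task == "text-generation-inference":
#       {"inputs": "<prompt from prompt_factory>", "parameters": {...}, "stream": False}
#   special params ("use_cache", "wait_for_model") are re-added under an "options" key.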
def get_api_base(self, api_base: Optional[str], model: str) -> str:
"""
Get the API base for the Huggingface API.
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if "https" in model:
completion_url = model
elif api_base is not None:
completion_url = api_base
elif "HF_API_BASE" in os.environ:
completion_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
return completion_url
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = {**headers, **default_headers}
return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(
status_code=status_code, message=error_message, headers=headers
)
def _convert_streamed_response_to_complete_response(
self,
response: httpx.Response,
logging_obj: LoggingClass,
model: str,
data: dict,
api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
## LOGGING
logging_obj.post_call(
input=data,
api_key=api_key,
original_response=completion_response,
additional_args={"complete_input_dict": data},
)
return completion_response
def convert_to_model_response_object( # noqa: PLR0915
self,
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
model_response: ModelResponse,
task: Optional[hf_tasks],
optional_params: dict,
encoding: Any,
messages: List[AllMessageValues],
model: str,
):
if task is None:
task = "text-generation-inference" # default to tgi
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response.choices[0].message.content = completion_response[ # type: ignore
"generated_text"
]
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]
):
raise HuggingFaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
headers=None,
)
if len(completion_response[0]["generated_text"]) > 0:
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices.extend(choices_list)
elif task == "text-classification":
model_response.choices[0].message.content = json.dumps( # type: ignore
completion_response
)
else:
if (
isinstance(completion_response, list)
and len(completion_response[0]["generated_text"]) > 0
):
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = token_counter(model=model, messages=messages)
except Exception:
            # This should remain non-blocking; we should not block a response from returning if calculating usage fails.
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
                ) ## [TODO] use the llama2 tokenizer here
except Exception:
                # This should remain non-blocking; we should not block a response from returning if calculating usage fails.
pass
else:
completion_tokens = 0
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
model_response._hidden_params["original_response"] = completion_response
return model_response
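        # Descriptive note (not part of this commit): usage is estimated client-side -
        # prompt_tokens via token_counter over the input messages, completion_tokens by
        # encoding the returned text - and any failure while counting is swallowed so
        # the response itself is never blocked.
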
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
        ## Some servers might return a streaming response even though stream was not set to True (e.g. Baseten).
task = litellm_params.get("task", None)
is_streamed = False
if (
raw_response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
completion_response = self._convert_streamed_response_to_complete_response(
response=raw_response,
logging_obj=logging_obj,
model=model,
data=request_data,
api_key=api_key,
)
else:
## LOGGING
logging_obj.post_call(
input=request_data,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
if isinstance(completion_response, dict):
completion_response = [completion_response]
except Exception:
raise HuggingFaceError(
message=f"Original Response received: {raw_response.text}",
status_code=raw_response.status_code,
)
if isinstance(completion_response, dict) and "error" in completion_response:
raise HuggingFaceError(
message=completion_response["error"], # type: ignore
status_code=raw_response.status_code,
)
return self.convert_to_model_response_object(
completion_response=completion_response,
model_response=model_response,
task=task if task is not None and task in hf_task_list else None,
optional_params=optional_params,
encoding=encoding,
messages=messages,
model=model,
)
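
A minimal standalone sketch of the streamed-vs-JSON detection that transform_response performs above (illustrative only, not part of this commit; assumes an httpx.Response object):

import httpx

def is_streamed_hf_response(response: httpx.Response) -> bool:
    # Some servers answer with SSE even when `stream` was not requested, so the
    # Content-Type header is checked before deciding how to parse the body.
    return response.headers.get("Content-Type", "") == "text/event-stream"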

View file

@ -402,4 +402,4 @@ class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
choices=chunk["choices"],
)
except Exception as e:
raise e
raise e

View file

@ -141,7 +141,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
from .llms.deprecated_providers import aleph_alpha, palm
from .llms.groq.chat.handler import GroqChatCompletion
from .llms.huggingface.chat.handler import Huggingface
from .llms.huggingface.embedding.handler import HuggingFaceEmbedding
from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
from .llms.ollama.completion import handler as ollama
from .llms.oobabooga.chat import oobabooga
@ -221,7 +221,7 @@ azure_chat_completions = AzureChatCompletion()
azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
azure_text_completions = AzureTextCompletion()
azure_audio_transcriptions = AzureAudioTranscription()
huggingface = Huggingface()
huggingface_embed = HuggingFaceEmbedding()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
bedrock_converse_chat_completion = BedrockConverseLLM()
@ -2141,7 +2141,6 @@ def completion( # type: ignore # noqa: PLR0915
response = model_response
elif custom_llm_provider == "huggingface":
custom_llm_provider = "huggingface"
huggingface_key = (
api_key
or litellm.huggingface_key
@ -2150,40 +2149,23 @@ def completion( # type: ignore # noqa: PLR0915
or litellm.api_key
)
hf_headers = headers or litellm.headers
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = huggingface.completion(
response = base_llm_http_handler.completion(
model=model,
messages=messages,
api_base=api_base, # type: ignore
headers=hf_headers or {},
headers=hf_headers,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=huggingface_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
optional_params=optional_params,
litellm_params=litellm_params,
timeout=timeout, # type: ignore
client=client,
custom_llm_provider=custom_llm_provider,
encoding=encoding,
stream=stream,
)
if (
"stream" in optional_params
and optional_params["stream"] is True
and acompletion is False
):
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="huggingface",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "oobabooga":
custom_llm_provider = "oobabooga"
model_response = oobabooga.completion(
@ -3623,7 +3605,7 @@ def embedding( # noqa: PLR0915
or get_secret("HUGGINGFACE_API_KEY")
or litellm.api_key
) # type: ignore
response = huggingface.embedding(
response = huggingface_embed.embedding(
model=model,
input=input,
encoding=encoding, # type: ignore

View file

@ -3225,7 +3225,7 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "huggingface":
optional_params = litellm.HuggingfaceConfig().map_openai_params(
optional_params = litellm.HuggingFaceChatConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
@ -6270,7 +6270,7 @@ class ProviderConfigManager:
elif litellm.LlmProviders.REPLICATE == provider:
return litellm.ReplicateConfig()
elif litellm.LlmProviders.HUGGINGFACE == provider:
return litellm.HuggingfaceConfig()
return litellm.HuggingFaceChatConfig()
elif litellm.LlmProviders.TOGETHER_AI == provider:
return litellm.TogetherAIConfig()
elif litellm.LlmProviders.OPENROUTER == provider:
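
For context, a sketch of how the renamed config is resolved (illustrative only; the lookup method name is assumed from current litellm and is not shown in this hunk):

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    provider=litellm.LlmProviders.HUGGINGFACE,
)
# After this change, `config` is a litellm.HuggingFaceChatConfig instance.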

View file

@ -1,169 +0,0 @@
"""
Unit Tests Huggingface route
"""
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
import pytest
def tgi_mock_post(url, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = [
{
"generated_text": "<|assistant|>\nI'm",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": None,
"prefill": [],
"tokens": [
{
"id": 28789,
"text": "<",
"logprob": -0.025222778,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.000003695488,
"special": False,
},
{
"id": 489,
"text": "ass",
"logprob": -0.0000019073486,
"special": False,
},
{
"id": 11143,
"text": "istant",
"logprob": -0.000002026558,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.0000015497208,
"special": False,
},
{
"id": 28767,
"text": ">",
"logprob": -0.0000011920929,
"special": False,
},
{
"id": 13,
"text": "\n",
"logprob": -0.00009703636,
"special": False,
},
{"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
{
"id": 28742,
"text": "'",
"logprob": -0.88183594,
"special": False,
},
{
"id": 28719,
"text": "m",
"logprob": -0.00032639503,
"special": False,
},
],
},
}
]
return mock_response
@pytest.fixture
def huggingface_chat_completion_call():
def _call(
model="huggingface/my-test-model",
messages=None,
api_key="test_api_key",
headers=None,
client=None,
):
if messages is None:
messages = [{"role": "user", "content": "Hello, how are you?"}]
if client is None:
client = HTTPHandler()
mock_response = Mock()
with patch.object(client, "post", side_effect=tgi_mock_post) as mock_post:
completion(
model=model,
messages=messages,
api_key=api_key,
headers=headers or {},
client=client,
)
return mock_post
return _call
@pytest.fixture
def async_huggingface_chat_completion_call():
async def _call(
model="huggingface/my-test-model",
messages=None,
api_key="test_api_key",
headers=None,
client=None,
):
if messages is None:
messages = [{"role": "user", "content": "Hello, how are you?"}]
if client is None:
client = AsyncHTTPHandler()
with patch.object(client, "post", side_effect=tgi_mock_post) as mock_post:
await acompletion(
model=model,
messages=messages,
api_key=api_key,
headers=headers or {},
client=client,
)
return mock_post
return _call
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_huggingface_chat_completions_endpoint(
sync_mode, huggingface_chat_completion_call, async_huggingface_chat_completion_call
):
model = "huggingface/another-model"
messages = [{"role": "user", "content": "Test message"}]
if sync_mode:
mock_post = huggingface_chat_completion_call(model=model, messages=messages)
else:
mock_post = await async_huggingface_chat_completion_call(
model=model, messages=messages
)
assert mock_post.call_count == 1

View file

@ -0,0 +1,358 @@
"""
Test HuggingFace LLM
"""
from re import M
import httpx
from base_llm_unit_tests import BaseLLMChatTest
import json
import os
import sys
from unittest.mock import patch, MagicMock, AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
import pytest
from litellm.types.utils import ModelResponseStream, ModelResponse
from respx import MockRouter
MOCK_COMPLETION_RESPONSE = {
"id": "9115d3daeab10608",
"object": "chat.completion",
"created": 11111,
"model": "meta-llama/Meta-Llama-3-8B-Instruct",
"prompt": [],
"choices": [
{
"finish_reason": "stop",
"seed": 3629048360264764400,
"logprobs": None,
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response from the mocked HuggingFace API.",
"tool_calls": []
}
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}
MOCK_STREAMING_CHUNKS = [
{"id": "id1", "object": "chat.completion.chunk", "created": 1111,
"choices": [{"index": 0, "text": "Deep", "logprobs": None, "finish_reason": None, "seed": None,
"delta": {"token_id": 34564, "role": "assistant", "content": "Deep", "tool_calls": None}}],
"model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo", "usage": None},
{"id": "id2", "object": "chat.completion.chunk", "created": 1111,
"choices": [{"index": 0, "text": " learning", "logprobs": None, "finish_reason": None, "seed": None,
"delta": {"token_id": 6975, "role": "assistant", "content": " learning", "tool_calls": None}}],
"model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo", "usage": None},
{"id": "id3", "object": "chat.completion.chunk", "created": 1111,
"choices": [{"index": 0, "text": " is", "logprobs": None, "finish_reason": None, "seed": None,
"delta": {"token_id": 374, "role": "assistant", "content": " is", "tool_calls": None}}],
"model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo", "usage": None},
{"id": "sid4", "object": "chat.completion.chunk", "created": 1111,
"choices": [{"index": 0, "text": " response", "logprobs": None, "finish_reason": "length", "seed": 2853637492034609700,
"delta": {"token_id": 323, "role": "assistant", "content": " response", "tool_calls": None}}],
"model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
"usage": {"prompt_tokens": 26, "completion_tokens": 20, "total_tokens": 46}}
]
PROVIDER_MAPPING_RESPONSE = {
"fireworks-ai": {
"status": "live",
"providerId": "accounts/fireworks/models/llama-v3-8b-instruct",
"task": "conversational"
},
"together": {
"status": "live",
"providerId": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
"task": "conversational"
},
"hf-inference": {
"status": "live",
"providerId": "meta-llama/Meta-Llama-3-8B-Instruct",
"task": "conversational"
},
}
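# Note (illustrative, not part of this commit): this mirrors the shape returned by
# _fetch_inference_provider_mapping for a Hub model - one entry per provider, keyed
# by provider name, carrying the provider-side model id ("providerId") and the task.
# transform_request swaps the requested model for that providerId, which
# test_transform_request below asserts.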
@pytest.fixture
def mock_provider_mapping():
with patch("litellm.llms.huggingface.chat.transformation._fetch_inference_provider_mapping") as mock:
mock.return_value = PROVIDER_MAPPING_RESPONSE
yield mock
@pytest.fixture(autouse=True)
def clear_lru_cache():
from litellm.llms.huggingface.common_utils import _fetch_inference_provider_mapping
_fetch_inference_provider_mapping.cache_clear()
yield
_fetch_inference_provider_mapping.cache_clear()
@pytest.fixture
def mock_http_handler():
"""Fixture to mock the HTTP handler"""
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post"
) as mock:
print(f"Creating mock HTTP handler: {mock}")
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.status_code = 200
def mock_side_effect(*args, **kwargs):
if kwargs.get("stream", True):
mock_response.iter_lines.return_value = iter([
f"data: {json.dumps(chunk)}".encode('utf-8')
for chunk in MOCK_STREAMING_CHUNKS
] + [b'data: [DONE]'])
else:
mock_response.json.return_value = MOCK_COMPLETION_RESPONSE
return mock_response
mock.side_effect = mock_side_effect
yield mock
@pytest.fixture
def mock_http_async_handler():
"""Fixture to mock the async HTTP handler"""
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock
) as mock:
print(f"Creating mock async HTTP handler: {mock}")
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.status_code = 200
mock_response.headers = {"content-type": "application/json"}
mock_response.json.return_value = MOCK_COMPLETION_RESPONSE
mock_response.text = json.dumps(MOCK_COMPLETION_RESPONSE)
async def mock_side_effect(*args, **kwargs):
if kwargs.get("stream", True):
async def mock_aiter():
for chunk in MOCK_STREAMING_CHUNKS:
yield f"data: {json.dumps(chunk)}".encode('utf-8')
yield b"data: [DONE]"
mock_response.aiter_lines = mock_aiter
return mock_response
mock.side_effect = mock_side_effect
yield mock
class TestHuggingFace(BaseLLMChatTest):
@pytest.fixture(autouse=True)
def setup(self, mock_provider_mapping, mock_http_handler, mock_http_async_handler):
self.mock_provider_mapping = mock_provider_mapping
self.mock_http = mock_http_handler
self.mock_http_async = mock_http_async_handler
self.model = "huggingface/together/meta-llama/Meta-Llama-3-8B-Instruct"
litellm.set_verbose = False
def get_base_completion_call_args(self) -> dict:
"""Implementation of abstract method from BaseLLMChatTest"""
return {"model": self.model}
def test_completion_non_streaming(self):
messages = [{"role": "user", "content": "This is a dummy message"}]
response = litellm.completion(
model=self.model,
messages=messages,
stream=False
)
assert isinstance(response, ModelResponse)
assert response.choices[0].message.content == "This is a test response from the mocked HuggingFace API."
assert response.usage is not None
assert response.model == self.model.split("/",2)[2]
def test_completion_streaming(self):
messages = [{"role": "user", "content": "This is a dummy message"}]
response = litellm.completion(
model=self.model,
messages=messages,
stream=True
)
chunks = list(response)
assert len(chunks) > 0
assert self.mock_http.called
call_args = self.mock_http.call_args
assert call_args is not None
kwargs = call_args[1]
data = json.loads(kwargs["data"])
assert data["stream"] is True
assert data["messages"] == messages
assert isinstance(chunks, list)
assert isinstance(chunks[0], ModelResponseStream)
assert isinstance(chunks[0].id, str)
assert chunks[0].model == self.model.split("/",1)[1]
@pytest.mark.asyncio
async def test_async_completion_streaming(self):
"""Test async streaming completion"""
messages = [{"role": "user", "content": "This is a dummy message"}]
response = await litellm.acompletion(
model=self.model,
messages=messages,
stream=True
)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert self.mock_http_async.called
assert len(chunks) > 0
assert isinstance(chunks[0], ModelResponseStream)
assert isinstance(chunks[0].id, str)
assert chunks[0].model == self.model.split("/",1)[1]
@pytest.mark.asyncio
async def test_async_completion_non_streaming(self):
"""Test async non-streaming completion"""
messages = [{"role": "user", "content": "This is a dummy message"}]
response = await litellm.acompletion(
model=self.model,
messages=messages,
stream=False
)
assert self.mock_http_async.called
assert isinstance(response, ModelResponse)
assert response.choices[0].message.content == "This is a test response from the mocked HuggingFace API."
assert response.usage is not None
assert response.model == self.model.split("/",2)[2]
def test_tool_call_no_arguments(self, tool_call_no_arguments):
mock_tool_response = {
**MOCK_COMPLETION_RESPONSE,
"choices": [{
"finish_reason": "tool_calls",
"index": 0,
"message": tool_call_no_arguments
}]
}
with patch.object(self.mock_http, "side_effect", lambda *args, **kwargs: MagicMock(
status_code=200,
json=lambda: mock_tool_response,
raise_for_status=lambda: None
)):
messages = [{"role": "user", "content": "Get the FAQ"}]
tools = [{
"type": "function",
"function": {
"name": "Get-FAQ",
"description": "Get FAQ information",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
}]
response = litellm.completion(
model=self.model,
messages=messages,
tools=tools,
tool_choice="auto"
)
assert response.choices[0].message.tool_calls is not None
assert len(response.choices[0].message.tool_calls) == 1
assert response.choices[0].message.tool_calls[0].function.name == tool_call_no_arguments["tool_calls"][0]["function"]["name"]
assert response.choices[0].message.tool_calls[0].function.arguments == tool_call_no_arguments["tool_calls"][0]["function"]["arguments"]
@pytest.mark.parametrize(
"model, provider, expected_url",
[
("meta-llama/Llama-3-8B-Instruct", None, "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3-8B-Instruct/v1/chat/completions"),
("together/meta-llama/Llama-3-8B-Instruct", None, "https://router.huggingface.co/together/v1/chat/completions"),
("novita/meta-llama/Llama-3-8B-Instruct", None, "https://router.huggingface.co/novita/chat/completions"),
("http://custom-endpoint.com/v1/chat/completions", None, "http://custom-endpoint.com/v1/chat/completions"),
],
)
def test_get_complete_url(self, model, provider, expected_url):
"""Test that the complete URL is constructed correctly for different providers"""
from litellm.llms.huggingface.chat.transformation import HuggingFaceChatConfig
config = HuggingFaceChatConfig()
url = config.get_complete_url(
api_base=None,
model=model,
optional_params={},
stream=False,
api_key="test_api_key",
litellm_params={}
)
assert url == expected_url
def test_validate_environment(self):
"""Test that the environment is validated correctly"""
from litellm.llms.huggingface.chat.transformation import HuggingFaceChatConfig
config = HuggingFaceChatConfig()
headers = config.validate_environment(
headers={},
model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "Hello"}],
optional_params={},
api_key="test_api_key"
)
assert headers["Authorization"] == "Bearer test_api_key"
assert headers["content-type"] == "application/json"
@pytest.mark.parametrize(
"model, expected_model",
[
("together/meta-llama/Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct-Turbo"),
("meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct"),
],
)
def test_transform_request(self, model, expected_model):
from litellm.llms.huggingface.chat.transformation import HuggingFaceChatConfig
config = HuggingFaceChatConfig()
messages = [{"role": "user", "content": "Hello"}]
transformed_request = config.transform_request(
model=model,
messages=messages,
optional_params={},
litellm_params={},
headers={}
)
assert transformed_request["model"] == expected_model
assert transformed_request["messages"] == messages
@pytest.mark.asyncio
async def test_completion_cost(self):
pass

View file

@ -169,18 +169,6 @@ def test_all_model_configs():
drop_params=False,
) == {"max_tokens": 10}
from litellm.llms.huggingface.chat.handler import HuggingfaceConfig
assert "max_completion_tokens" in HuggingfaceConfig().get_supported_openai_params(
model="llama3"
)
assert HuggingfaceConfig().map_openai_params(
non_default_params={"max_completion_tokens": 10},
optional_params={},
model="llama3",
drop_params=False,
) == {"max_new_tokens": 10}
from litellm.llms.nvidia_nim.chat import NvidiaNimConfig
assert "max_completion_tokens" in NvidiaNimConfig().get_supported_openai_params(

View file

@ -64,28 +64,6 @@ def test_convert_chat_to_text_completion():
)
def test_convert_provider_response_logprobs():
"""Test converting provider logprobs to text completion logprobs"""
response = ModelResponse(
id="test123",
_hidden_params={
"original_response": {
"details": {"tokens": [{"text": "hello", "logprob": -1.0}]}
}
},
)
result = LiteLLMResponseObjectHandler._convert_provider_response_logprobs_to_text_completion_logprobs(
response=response, custom_llm_provider="huggingface"
)
# Note: The actual assertion here depends on the implementation of
# litellm.huggingface._transform_logprobs, but we can at least test the function call
assert (
result is not None or result is None
) # Will depend on the actual implementation
def test_convert_provider_response_logprobs_non_huggingface():
"""Test converting provider logprobs for non-huggingface provider"""
response = ModelResponse(id="test123", _hidden_params={})

View file

@ -1575,148 +1575,6 @@ HF Tests we should pass
"""
#####################################################
#####################################################
# Test util to sort models to TGI, conv, None
from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig
def test_get_hf_task_for_model():
model = "glaiveai/glaive-coder-7b"
model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference"
model = "meta-llama/Llama-2-7b-hf"
model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference"
model = "facebook/blenderbot-400M-distill"
model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational"
model = "facebook/blenderbot-3B"
model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational"
# neither Conv or None
model = "roneneldan/TinyStories-3M"
model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation"
# test_get_hf_task_for_model()
# litellm.set_verbose=False
# ################### Hugging Face TGI models ########################
# # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
def tgi_mock_post(url, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = [
{
"generated_text": "<|assistant|>\nI'm",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": None,
"prefill": [],
"tokens": [
{
"id": 28789,
"text": "<",
"logprob": -0.025222778,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.000003695488,
"special": False,
},
{
"id": 489,
"text": "ass",
"logprob": -0.0000019073486,
"special": False,
},
{
"id": 11143,
"text": "istant",
"logprob": -0.000002026558,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.0000015497208,
"special": False,
},
{
"id": 28767,
"text": ">",
"logprob": -0.0000011920929,
"special": False,
},
{
"id": 13,
"text": "\n",
"logprob": -0.00009703636,
"special": False,
},
{"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
{
"id": 28742,
"text": "'",
"logprob": -0.88183594,
"special": False,
},
{
"id": 28719,
"text": "m",
"logprob": -0.00032639503,
"special": False,
},
],
},
}
]
return mock_response
def test_hf_test_completion_tgi():
litellm.set_verbose = True
try:
client = HTTPHandler()
with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
response = completion(
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
messages=[{"content": "Hello, how are you?", "role": "user"}],
max_tokens=10,
wait_for_model=True,
client=client,
)
mock_client.assert_called_once()
# Add any assertions-here to check the response
print(response)
assert "options" in mock_client.call_args.kwargs["data"]
json_data = json.loads(mock_client.call_args.kwargs["data"])
assert "wait_for_model" in json_data["options"]
assert json_data["options"]["wait_for_model"] is True
except litellm.ServiceUnavailableError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# hf_test_completion_tgi()
@pytest.mark.parametrize(
"provider", ["openai", "hosted_vllm", "lm_studio"]
@ -1866,26 +1724,6 @@ def mock_post(url, **kwargs):
return mock_response
def test_hf_classifier_task():
try:
client = HTTPHandler()
with patch.object(client, "post", side_effect=mock_post):
litellm.set_verbose = True
user_message = "I like you. I love you"
messages = [{"content": user_message, "role": "user"}]
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
client=client,
)
print(f"response: {response}")
assert isinstance(response, litellm.ModelResponse)
assert isinstance(response.choices[0], litellm.Choices)
assert response.choices[0].message.content is not None
assert isinstance(response.choices[0].message.content, str)
except Exception as e:
pytest.fail(f"Error occurred: {str(e)}")
def test_ollama_image():
"""

View file

@ -643,8 +643,8 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
@pytest.mark.asyncio
@patch("litellm.llms.huggingface.chat.handler.async_get_hf_task_embedding_for_model")
@patch("litellm.llms.huggingface.chat.handler.get_hf_task_embedding_for_model")
@patch("litellm.llms.huggingface.embedding.handler.async_get_hf_task_embedding_for_model")
@patch("litellm.llms.huggingface.embedding.handler.get_hf_task_embedding_for_model")
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_hf_embedding_sentence_sim(
mock_async_get_hf_task_embedding_for_model,

View file

@ -370,7 +370,7 @@ def test_get_model_info_huggingface_models(monkeypatch):
"model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
"litellm_params": {
"model": "huggingface/meta-llama/Meta-Llama-3-8B-Instruct",
"api_base": "https://api-inference.huggingface.co/models/meta-llama/Llama-3.3-70B-Instruct",
"api_base": "https://router.huggingface.co/hf-inference/models/meta-llama/Meta-Llama-3-8B-Instruct",
"api_key": os.environ["HUGGINGFACE_API_KEY"],
},
}

View file

@ -1,75 +0,0 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
def test_prompt_formatting():
try:
prompt = prompt_factory(
model="mistralai/Mistral-7B-Instruct-v0.1",
messages=[
{"role": "system", "content": "Be a good bot"},
{"role": "user", "content": "Hello world"},
],
)
assert (
prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
)
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
def test_prompt_formatting_custom_model():
try:
prompt = prompt_factory(
model="ehartford/dolphin-2.5-mixtral-8x7b",
messages=[
{"role": "system", "content": "Be a good bot"},
{"role": "user", "content": "Hello world"},
],
custom_llm_provider="huggingface",
)
print(f"prompt: {prompt}")
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_prompt_formatting_custom_model()
# def logger_fn(user_model_dict):
# return
# print(f"user_model_dict: {user_model_dict}")
# messages=[{"role": "user", "content": "Write me a function to print hello world"}]
# # test if the first-party prompt templates work
# def test_huggingface_supported_models():
# model = "huggingface/WizardLM/WizardCoder-Python-34B-V1.0"
# response = completion(model=model, messages=messages, max_tokens=256, api_base="https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
# print(response['choices'][0]['message']['content'])
# return response
# test_huggingface_supported_models()
# # test if a custom prompt template works
# litellm.register_prompt_template(
# model="togethercomputer/LLaMA-2-7B-32K",
# roles={"system":"", "assistant":"Assistant:", "user":"User:"},
# pre_message_sep= "\n",
# post_message_sep= "\n"
# )
# def test_huggingface_custom_model():
# model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
# response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
# print(response['choices'][0]['message']['content'])
# return response
# test_huggingface_custom_model()