Merge branch 'main' into fix_budget_limits

This commit is contained in:
Sam 2025-04-07 14:38:08 +10:00 committed by GitHub
commit 3f8b827f79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
290 changed files with 10057 additions and 4647 deletions

View file

@ -10,6 +10,6 @@ anthropic
orjson==3.9.15 orjson==3.9.15
pydantic==2.10.2 pydantic==2.10.2
google-cloud-aiplatform==1.43.0 google-cloud-aiplatform==1.43.0
fastapi-sso==0.10.0 fastapi-sso==0.16.0
uvloop==0.21.0 uvloop==0.21.0
mcp==1.5.0 # for MCP server mcp==1.5.0 # for MCP server

View file

@ -6,8 +6,9 @@
"id": "9dKM5k8qsMIj" "id": "9dKM5k8qsMIj"
}, },
"source": [ "source": [
"## LiteLLM HuggingFace\n", "## LiteLLM Hugging Face\n",
"Docs for huggingface: https://docs.litellm.ai/docs/providers/huggingface" "\n",
"Docs for huggingface: https://docs.litellm.ai/docs/providers/huggingface\n"
] ]
}, },
{ {
@ -27,23 +28,18 @@
"id": "yp5UXRqtpu9f" "id": "yp5UXRqtpu9f"
}, },
"source": [ "source": [
"## Hugging Face Free Serverless Inference API\n", "## Serverless Inference Providers\n",
"Read more about the Free Serverless Inference API here: https://huggingface.co/docs/api-inference.\n",
"\n", "\n",
"In order to use litellm to call Serverless Inference API:\n", "Read more about Inference Providers here: https://huggingface.co/blog/inference-providers.\n",
"\n", "\n",
"* Browse Serverless Inference compatible models here: https://huggingface.co/models?inference=warm&pipeline_tag=text-generation.\n", "In order to use litellm with Hugging Face Inference Providers, you need to set `model=huggingface/<provider>/<model-id>`.\n",
"* Copy the model name from hugging face\n",
"* Set `model = \"huggingface/<model-name>\"`\n",
"\n", "\n",
"Example set `model=huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct` to call `meta-llama/Meta-Llama-3.1-8B-Instruct`\n", "Example: `huggingface/together/deepseek-ai/DeepSeek-R1` to run DeepSeek-R1 (https://huggingface.co/deepseek-ai/DeepSeek-R1) through Together AI.\n"
"\n",
"https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@ -51,107 +47,18 @@
"id": "Pi5Oww8gpCUm", "id": "Pi5Oww8gpCUm",
"outputId": "659a67c7-f90d-4c06-b94e-2c4aa92d897a" "outputId": "659a67c7-f90d-4c06-b94e-2c4aa92d897a"
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelResponse(id='chatcmpl-c54dfb68-1491-4d68-a4dc-35e603ea718a', choices=[Choices(finish_reason='eos_token', index=0, message=Message(content=\"I'm just a computer program, so I don't have feelings, but thank you for asking! How can I assist you today?\", role='assistant', tool_calls=None, function_call=None))], created=1724858285, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=27, prompt_tokens=47, total_tokens=74))\n",
"ModelResponse(id='chatcmpl-d2ae38e6-4974-431c-bb9b-3fa3f95e5a6d', choices=[Choices(finish_reason='length', index=0, message=Message(content=\"\\n\\nIm doing well, thank you. Ive been keeping busy with work and some personal projects. How about you?\\n\\nI'm doing well, thank you. I've been enjoying some time off and catching up on some reading. How can I assist you today?\\n\\nI'm looking for a good book to read. Do you have any recommendations?\\n\\nOf course! Here are a few book recommendations across different genres:\\n\\n1.\", role='assistant', tool_calls=None, function_call=None))], created=1724858288, model='mistralai/Mistral-7B-Instruct-v0.3', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=85, prompt_tokens=6, total_tokens=91))\n"
]
}
],
"source": [ "source": [
"import os\n", "import os\n",
"import litellm\n", "from litellm import completion\n",
"\n", "\n",
"# Make sure to create an API_KEY with inference permissions at https://huggingface.co/settings/tokens/new?globalPermissions=inference.serverless.write&tokenType=fineGrained\n", "# You can create a HF token here: https://huggingface.co/settings/tokens\n",
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n", "os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"\n", "\n",
"# Call https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct\n", "# Call DeepSeek-R1 model through Together AI\n",
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n", "response = completion(\n",
"response = litellm.completion(\n", " model=\"huggingface/together/deepseek-ai/DeepSeek-R1\",\n",
" model=\"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " messages=[{\"content\": \"How many r's are in the word `strawberry`?\", \"role\": \"user\"}],\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}]\n",
")\n",
"print(response)\n",
"\n",
"\n",
"# Call https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3\n",
"response = litellm.completion(\n",
" model=\"huggingface/mistralai/Mistral-7B-Instruct-v0.3\",\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}]\n",
")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-klhAhjLtclv"
},
"source": [
"## Hugging Face Dedicated Inference Endpoints\n",
"\n",
"Steps to use\n",
"* Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/\n",
"* Set `api_base` to your deployed api base\n",
"* Add the `huggingface/` prefix to your model so litellm knows it's a huggingface Deployed Inference Endpoint"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Lbmw8Gl_pHns",
"outputId": "ea8408bf-1cc3-4670-ecea-f12666d204a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"object\": \"chat.completion\",\n",
" \"choices\": [\n",
" {\n",
" \"finish_reason\": \"length\",\n",
" \"index\": 0,\n",
" \"message\": {\n",
" \"content\": \"\\n\\nI am doing well, thank you for asking. How about you?\\nI am doing\",\n",
" \"role\": \"assistant\",\n",
" \"logprobs\": -8.9481967812\n",
" }\n",
" }\n",
" ],\n",
" \"id\": \"chatcmpl-74dc9d89-3916-47ce-9bea-b80e66660f77\",\n",
" \"created\": 1695871068.8413374,\n",
" \"model\": \"glaiveai/glaive-coder-7b\",\n",
" \"usage\": {\n",
" \"prompt_tokens\": 6,\n",
" \"completion_tokens\": 18,\n",
" \"total_tokens\": 24\n",
" }\n",
"}\n"
]
}
],
"source": [
"import os\n",
"import litellm\n",
"\n",
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
"\n",
"# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b\n",
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n",
"# set api base to your deployed api endpoint from hugging face\n",
"response = litellm.completion(\n",
" model=\"huggingface/glaiveai/glaive-coder-7b\",\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n",
" api_base=\"https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud\"\n",
")\n", ")\n",
"print(response)" "print(response)"
] ]
@ -162,13 +69,12 @@
"id": "EU0UubrKzTFe" "id": "EU0UubrKzTFe"
}, },
"source": [ "source": [
"## HuggingFace - Streaming (Serveless or Dedicated)\n", "## Streaming\n"
"Set stream = True"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@ -176,74 +82,147 @@
"id": "y-QfIvA-uJKX", "id": "y-QfIvA-uJKX",
"outputId": "b007bb98-00d0-44a4-8264-c8a2caed6768" "outputId": "b007bb98-00d0-44a4-8264-c8a2caed6768"
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"<litellm.utils.CustomStreamWrapper object at 0x1278471d0>\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='I', role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"'m\", role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' just', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' a', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' computer', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' program', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=',', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' so', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' I', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' don', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"'t\", role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' have', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' feelings', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=',', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' but', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' thank', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' you', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' for', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' asking', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='!', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' How', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' can', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' I', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' assist', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' you', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' today', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='?', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='<|eot_id|>', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n"
]
}
],
"source": [ "source": [
"import os\n", "import os\n",
"import litellm\n", "from litellm import completion\n",
"\n", "\n",
"# Make sure to create an API_KEY with inference permissions at https://huggingface.co/settings/tokens/new?globalPermissions=inference.serverless.write&tokenType=fineGrained\n", "os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
"\n", "\n",
"# Call https://huggingface.co/glaiveai/glaive-coder-7b\n", "response = completion(\n",
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n", " model=\"huggingface/together/deepseek-ai/DeepSeek-R1\",\n",
"# set api base to your deployed api endpoint from hugging face\n", " messages=[\n",
"response = litellm.completion(\n", " {\n",
" model=\"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " \"role\": \"user\",\n",
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", " \"content\": \"How many r's are in the word `strawberry`?\",\n",
" stream=True\n", " \n",
" }\n",
" ],\n",
" stream=True,\n",
")\n", ")\n",
"\n", "\n",
"print(response)\n",
"\n",
"for chunk in response:\n", "for chunk in response:\n",
" print(chunk)" " print(chunk)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## With images as input\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"id": "CKXAnK55zQRl"
},
"outputs": [], "outputs": [],
"source": [] "source": [
"from litellm import completion\n",
"\n",
"# Set your Hugging Face Token\n",
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"\n",
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"text\", \"text\": \"What's in this image?\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\",\n",
" },\n",
" },\n",
" ],\n",
" }\n",
"]\n",
"\n",
"response = completion(\n",
" model=\"huggingface/sambanova/meta-llama/Llama-3.3-70B-Instruct\",\n",
" messages=messages,\n",
")\n",
"print(response.choices[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tools - Function Calling\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from litellm import completion\n",
"\n",
"\n",
"# Set your Hugging Face Token\n",
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
"\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
" },\n",
" \"required\": [\"location\"],\n",
" },\n",
" },\n",
" }\n",
"]\n",
"messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n",
"\n",
"response = completion(\n",
" model=\"huggingface/sambanova/meta-llama/Llama-3.1-8B-Instruct\", messages=messages, tools=tools, tool_choice=\"auto\"\n",
")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hugging Face Dedicated Inference Endpoints\n",
"\n",
"Steps to use\n",
"\n",
"- Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/\n",
"- Set `api_base` to your deployed api base\n",
"- set the model to `huggingface/tgi` so that litellm knows it's a huggingface Deployed Inference Endpoint.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import litellm\n",
"\n",
"\n",
"response = litellm.completion(\n",
" model=\"huggingface/tgi\",\n",
" messages=[{\"content\": \"Hello, how are you?\", \"role\": \"user\"}],\n",
" api_base=\"https://my-endpoint.endpoints.huggingface.cloud/v1/\",\n",
")\n",
"print(response)"
]
} }
], ],
"metadata": { "metadata": {
@ -251,7 +230,8 @@
"provenance": [] "provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": ".venv",
"language": "python",
"name": "python3" "name": "python3"
}, },
"language_info": { "language_info": {
@ -264,7 +244,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.12.2" "version": "3.12.0"
} }
}, },
"nbformat": 4, "nbformat": 4,

View file

@ -27,16 +27,18 @@ os.environ["AWS_REGION_NAME"] = ""
# pdf url # pdf url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" file_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
# model # model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0" model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
image_content = [ file_content = [
{"type": "text", "text": "What's this file about?"}, {"type": "text", "text": "What's this file about?"},
{ {
"type": "image_url", "type": "file",
"image_url": image_url, # OR {"url": image_url} "file": {
"file_id": file_url,
}
}, },
] ]
@ -46,7 +48,7 @@ if not supports_pdf_input(model, None):
response = completion( response = completion(
model=model, model=model,
messages=[{"role": "user", "content": image_content}], messages=[{"role": "user", "content": file_content}],
) )
assert response is not None assert response is not None
``` ```
@ -80,11 +82,15 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-d '{ -d '{
"model": "bedrock-model", "model": "bedrock-model",
"messages": [ "messages": [
{"role": "user", "content": {"type": "text", "text": "What's this file about?"}}, {"role": "user", "content": [
{ {"type": "text", "text": "What's this file about?"},
"type": "image_url", {
"image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf", "type": "file",
} "file": {
"file_id": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
}
}
]},
] ]
}' }'
``` ```
@ -116,11 +122,13 @@ base64_url = f"data:application/pdf;base64,{encoded_file}"
# model # model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0" model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
image_content = [ file_content = [
{"type": "text", "text": "What's this file about?"}, {"type": "text", "text": "What's this file about?"},
{ {
"type": "image_url", "type": "file",
"image_url": base64_url, # OR {"url": base64_url} "file": {
"file_data": base64_url,
}
}, },
] ]
@ -130,11 +138,53 @@ if not supports_pdf_input(model, None):
response = completion( response = completion(
model=model, model=model,
messages=[{"role": "user", "content": image_content}], messages=[{"role": "user", "content": file_content}],
) )
assert response is not None assert response is not None
``` ```
</TabItem> </TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: bedrock-model
litellm_params:
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
```
2. Start the proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "bedrock-model",
"messages": [
{"role": "user", "content": [
{"type": "text", "text": "What's this file about?"},
{
"type": "file",
"file": {
"file_data": "data:application/pdf;base64...",
}
}
]},
]
}'
```
</TabItem>
</Tabs> </Tabs>
## Checking if a model supports pdf input ## Checking if a model supports pdf input
@ -200,92 +250,3 @@ Expected Response
</TabItem> </TabItem>
</Tabs> </Tabs>
## OpenAI 'file' message type
This is currently only supported for OpenAI models.
This will be supported for all providers soon.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import base64
from litellm import completion
with open("draconomicon.pdf", "rb") as f:
data = f.read()
base64_string = base64.b64encode(data).decode("utf-8")
completion = completion(
model="gpt-4o",
messages=[
{
"role": "user",
"content": [
{
"type": "file",
"file": {
"filename": "draconomicon.pdf",
"file_data": f"data:application/pdf;base64,{base64_string}",
}
},
{
"type": "text",
"text": "What is the first dragon in the book?",
}
],
},
],
)
print(completion.choices[0].message.content)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: openai-model
litellm_params:
model: gpt-4o
api_key: os.environ/OPENAI_API_KEY
```
2. Start the proxy
```bash
litellm --config config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "openai-model",
"messages": [
{"role": "user", "content": [
{
"type": "file",
"file": {
"filename": "draconomicon.pdf",
"file_data": f"data:application/pdf;base64,{base64_string}",
}
}
]}
]
}'
```
</TabItem>
</Tabs>

View file

@ -108,3 +108,75 @@ response = litellm.completion(
</Tabs> </Tabs>
**additional_drop_params**: List or null - Is a list of openai params you want to drop when making a call to the model. **additional_drop_params**: List or null - Is a list of openai params you want to drop when making a call to the model.
## Specify allowed openai params in a request
Tell litellm to allow specific openai params in a request. Use this if you get a `litellm.UnsupportedParamsError` and want to allow a param. LiteLLM will pass the param as is to the model.
<Tabs>
<TabItem value="sdk" label="LiteLLM Python SDK">
In this example we pass `allowed_openai_params=["tools"]` to allow the `tools` param.
```python showLineNumbers title="Pass allowed_openai_params to LiteLLM Python SDK"
await litellm.acompletion(
model="azure/o_series/<my-deployment-name>",
api_key="xxxxx",
api_base=api_base,
messages=[{"role": "user", "content": "Hello! return a json object"}],
tools=[{"type": "function", "function": {"name": "get_current_time", "description": "Get the current time in a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name, e.g. San Francisco"}}, "required": ["location"]}}}]
allowed_openai_params=["tools"],
)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM Proxy">
When using litellm proxy you can pass `allowed_openai_params` in two ways:
1. Dynamically pass `allowed_openai_params` in a request
2. Set `allowed_openai_params` on the config.yaml file for a specific model
#### Dynamically pass allowed_openai_params in a request
In this example we pass `allowed_openai_params=["tools"]` to allow the `tools` param for a request sent to the model set on the proxy.
```python showLineNumbers title="Dynamically pass allowed_openai_params in a request"
import openai
from openai import AsyncAzureOpenAI
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
extra_body={
"allowed_openai_params": ["tools"]
}
)
```
#### Set allowed_openai_params on config.yaml
You can also set `allowed_openai_params` on the config.yaml file for a specific model. This means that all requests to this deployment are allowed to pass in the `tools` param.
```yaml showLineNumbers title="Set allowed_openai_params on config.yaml"
model_list:
- model_name: azure-o1-preview
litellm_params:
model: azure/o_series/<my-deployment-name>
api_key: xxxxx
api_base: https://openai-prod-test.openai.azure.com/openai/deployments/o1/chat/completions?api-version=2025-01-01-preview
allowed_openai_params: ["tools"]
```
</TabItem>
</Tabs>

View file

@ -1076,32 +1076,24 @@ print(response)
``` ```
### Parallel Function calling ### Tool Calling / Function Calling
See a detailed walthrough of parallel function calling with litellm [here](https://docs.litellm.ai/docs/completion/function_call) See a detailed walthrough of parallel function calling with litellm [here](https://docs.litellm.ai/docs/completion/function_call)
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
# set Azure env variables # set Azure env variables
import os import os
import litellm
import json
os.environ['AZURE_API_KEY'] = "" # litellm reads AZURE_API_KEY from .env and sends the request os.environ['AZURE_API_KEY'] = "" # litellm reads AZURE_API_KEY from .env and sends the request
os.environ['AZURE_API_BASE'] = "https://openai-gpt-4-test-v-1.openai.azure.com/" os.environ['AZURE_API_BASE'] = "https://openai-gpt-4-test-v-1.openai.azure.com/"
os.environ['AZURE_API_VERSION'] = "2023-07-01-preview" os.environ['AZURE_API_VERSION'] = "2023-07-01-preview"
import litellm
import json
# Example dummy function hard coded to return the same weather
# In production, this could be your backend API or an external API
def get_current_weather(location, unit="fahrenheit"):
"""Get the current weather in a given location"""
if "tokyo" in location.lower():
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
elif "san francisco" in location.lower():
return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"})
elif "paris" in location.lower():
return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
else:
return json.dumps({"location": location, "temperature": "unknown"})
## Step 1: send the conversation and available functions to the model
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
tools = [ tools = [
{ {
"type": "function", "type": "function",
@ -1125,7 +1117,7 @@ tools = [
response = litellm.completion( response = litellm.completion(
model="azure/chatgpt-functioncalling", # model = azure/<your-azure-deployment-name> model="azure/chatgpt-functioncalling", # model = azure/<your-azure-deployment-name>
messages=messages, messages=[{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}],
tools=tools, tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit tool_choice="auto", # auto is default, but we'll be explicit
) )
@ -1134,8 +1126,49 @@ response_message = response.choices[0].message
tool_calls = response.choices[0].message.tool_calls tool_calls = response.choices[0].message.tool_calls
print("\nTool Choice:\n", tool_calls) print("\nTool Choice:\n", tool_calls)
``` ```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: azure-gpt-3.5
litellm_params:
model: azure/chatgpt-functioncalling
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
```
2. Start proxy
```bash
litellm --config config.yaml
```
3. Test it
```bash
curl -L -X POST 'http://localhost:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "azure-gpt-3.5",
"messages": [
{
"role": "user",
"content": "Hey, how'\''s it going? Thinking long and hard before replying - what is the meaning of the world and life itself"
}
]
}'
```
</TabItem>
</Tabs>
### Spend Tracking for Azure OpenAI Models (PROXY) ### Spend Tracking for Azure OpenAI Models (PROXY)
Set base model for cost tracking azure image-gen call Set base model for cost tracking azure image-gen call

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# 🆕 Databricks # Databricks
LiteLLM supports all models on Databricks LiteLLM supports all models on Databricks
@ -154,7 +154,205 @@ response = completion(
temperature: 0.5 temperature: 0.5
``` ```
## Passings Databricks specific params - 'instruction'
## Usage - Thinking / `reasoning_content`
LiteLLM translates OpenAI's `reasoning_effort` to Anthropic's `thinking` parameter. [Code](https://github.com/BerriAI/litellm/blob/23051d89dd3611a81617d84277059cd88b2df511/litellm/llms/anthropic/chat/transformation.py#L298)
| reasoning_effort | thinking |
| ---------------- | -------- |
| "low" | "budget_tokens": 1024 |
| "medium" | "budget_tokens": 2048 |
| "high" | "budget_tokens": 4096 |
Known Limitations:
- Support for passing thinking blocks back to Claude [Issue](https://github.com/BerriAI/litellm/issues/9790)
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
# set ENV variables (can also be passed in to .completion() - e.g. `api_base`, `api_key`)
os.environ["DATABRICKS_API_KEY"] = "databricks key"
os.environ["DATABRICKS_API_BASE"] = "databricks base url"
resp = completion(
model="databricks/databricks-claude-3-7-sonnet",
messages=[{"role": "user", "content": "What is the capital of France?"}],
reasoning_effort="low",
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
- model_name: claude-3-7-sonnet
litellm_params:
model: databricks/databricks-claude-3-7-sonnet
api_key: os.environ/DATABRICKS_API_KEY
api_base: os.environ/DATABRICKS_API_BASE
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
-d '{
"model": "claude-3-7-sonnet",
"messages": [{"role": "user", "content": "What is the capital of France?"}],
"reasoning_effort": "low"
}'
```
</TabItem>
</Tabs>
**Expected Response**
```python
ModelResponse(
id='chatcmpl-c542d76d-f675-4e87-8e5f-05855f5d0f5e',
created=1740470510,
model='claude-3-7-sonnet-20250219',
object='chat.completion',
system_fingerprint=None,
choices=[
Choices(
finish_reason='stop',
index=0,
message=Message(
content="The capital of France is Paris.",
role='assistant',
tool_calls=None,
function_call=None,
provider_specific_fields={
'citations': None,
'thinking_blocks': [
{
'type': 'thinking',
'thinking': 'The capital of France is Paris. This is a very straightforward factual question.',
'signature': 'EuYBCkQYAiJAy6...'
}
]
}
),
thinking_blocks=[
{
'type': 'thinking',
'thinking': 'The capital of France is Paris. This is a very straightforward factual question.',
'signature': 'EuYBCkQYAiJAy6AGB...'
}
],
reasoning_content='The capital of France is Paris. This is a very straightforward factual question.'
)
],
usage=Usage(
completion_tokens=68,
prompt_tokens=42,
total_tokens=110,
completion_tokens_details=None,
prompt_tokens_details=PromptTokensDetailsWrapper(
audio_tokens=None,
cached_tokens=0,
text_tokens=None,
image_tokens=None
),
cache_creation_input_tokens=0,
cache_read_input_tokens=0
)
)
```
### Pass `thinking` to Anthropic models
You can also pass the `thinking` parameter to Anthropic models.
You can also pass the `thinking` parameter to Anthropic models.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
# set ENV variables (can also be passed in to .completion() - e.g. `api_base`, `api_key`)
os.environ["DATABRICKS_API_KEY"] = "databricks key"
os.environ["DATABRICKS_API_BASE"] = "databricks base url"
response = litellm.completion(
model="databricks/databricks-claude-3-7-sonnet",
messages=[{"role": "user", "content": "What is the capital of France?"}],
thinking={"type": "enabled", "budget_tokens": 1024},
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $LITELLM_KEY" \
-d '{
"model": "databricks/databricks-claude-3-7-sonnet",
"messages": [{"role": "user", "content": "What is the capital of France?"}],
"thinking": {"type": "enabled", "budget_tokens": 1024}
}'
```
</TabItem>
</Tabs>
## Supported Databricks Chat Completion Models
:::tip
**We support ALL Databricks models, just set `model=databricks/<any-model-on-databricks>` as a prefix when sending litellm requests**
:::
| Model Name | Command |
|----------------------------|------------------------------------------------------------------|
| databricks/databricks-claude-3-7-sonnet | `completion(model='databricks/databricks/databricks-claude-3-7-sonnet', messages=messages)` |
| databricks-meta-llama-3-1-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-70b-instruct', messages=messages)` |
| databricks-meta-llama-3-1-405b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-405b-instruct', messages=messages)` |
| databricks-dbrx-instruct | `completion(model='databricks/databricks-dbrx-instruct', messages=messages)` |
| databricks-meta-llama-3-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-70b-instruct', messages=messages)` |
| databricks-llama-2-70b-chat | `completion(model='databricks/databricks-llama-2-70b-chat', messages=messages)` |
| databricks-mixtral-8x7b-instruct | `completion(model='databricks/databricks-mixtral-8x7b-instruct', messages=messages)` |
| databricks-mpt-30b-instruct | `completion(model='databricks/databricks-mpt-30b-instruct', messages=messages)` |
| databricks-mpt-7b-instruct | `completion(model='databricks/databricks-mpt-7b-instruct', messages=messages)` |
## Embedding Models
### Passing Databricks specific params - 'instruction'
For embedding models, databricks lets you pass in an additional param 'instruction'. [Full Spec](https://github.com/BerriAI/litellm/blob/43353c28b341df0d9992b45c6ce464222ebd7984/litellm/llms/databricks.py#L164) For embedding models, databricks lets you pass in an additional param 'instruction'. [Full Spec](https://github.com/BerriAI/litellm/blob/43353c28b341df0d9992b45c6ce464222ebd7984/litellm/llms/databricks.py#L164)
@ -187,27 +385,6 @@ response = litellm.embedding(
instruction: "Represent this sentence for searching relevant passages:" instruction: "Represent this sentence for searching relevant passages:"
``` ```
## Supported Databricks Chat Completion Models
:::tip
**We support ALL Databricks models, just set `model=databricks/<any-model-on-databricks>` as a prefix when sending litellm requests**
:::
| Model Name | Command |
|----------------------------|------------------------------------------------------------------|
| databricks-meta-llama-3-1-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-70b-instruct', messages=messages)` |
| databricks-meta-llama-3-1-405b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-405b-instruct', messages=messages)` |
| databricks-dbrx-instruct | `completion(model='databricks/databricks-dbrx-instruct', messages=messages)` |
| databricks-meta-llama-3-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-70b-instruct', messages=messages)` |
| databricks-llama-2-70b-chat | `completion(model='databricks/databricks-llama-2-70b-chat', messages=messages)` |
| databricks-mixtral-8x7b-instruct | `completion(model='databricks/databricks-mixtral-8x7b-instruct', messages=messages)` |
| databricks-mpt-30b-instruct | `completion(model='databricks/databricks-mpt-30b-instruct', messages=messages)` |
| databricks-mpt-7b-instruct | `completion(model='databricks/databricks-mpt-7b-instruct', messages=messages)` |
## Supported Databricks Embedding Models ## Supported Databricks Embedding Models
:::tip :::tip

View file

@ -887,3 +887,54 @@ response = await client.chat.completions.create(
</TabItem> </TabItem>
</Tabs> </Tabs>
## Image Generation
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
response = completion(
model="gemini/gemini-2.0-flash-exp-image-generation",
messages=[{"role": "user", "content": "Generate an image of a cat"}],
modalities=["image", "text"],
)
assert response.choices[0].message.content is not None # ".."
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-2.0-flash-exp-image-generation
litellm_params:
model: gemini/gemini-2.0-flash-exp-image-generation
api_key: os.environ/GEMINI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://localhost:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gemini-2.0-flash-exp-image-generation",
"messages": [{"role": "user", "content": "Generate an image of a cat"}],
"modalities": ["image", "text"]
}'
```
</TabItem>
</Tabs>

View file

@ -0,0 +1,161 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] Google AI Studio (Gemini) Files API
Use this to upload files to Google AI Studio (Gemini).
Useful to pass in large media files to Gemini's `/generateContent` endpoint.
| Action | Supported |
|----------|-----------|
| `create` | Yes |
| `delete` | No |
| `retrieve` | No |
| `list` | No |
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import base64
import requests
from litellm import completion, create_file
import os
### UPLOAD FILE ###
# Fetch the audio file and convert it to a base64 encoded string
url = "https://cdn.openai.com/API/docs/audio/alloy.wav"
response = requests.get(url)
response.raise_for_status()
wav_data = response.content
encoded_string = base64.b64encode(wav_data).decode('utf-8')
file = create_file(
file=wav_data,
purpose="user_data",
extra_body={"custom_llm_provider": "gemini"},
api_key=os.getenv("GEMINI_API_KEY"),
)
print(f"file: {file}")
assert file is not None
### GENERATE CONTENT ###
completion = completion(
model="gemini-2.0-flash",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this recording?"
},
{
"type": "file",
"file": {
"file_id": file.id,
"filename": "my-test-name",
"format": "audio/wav"
}
}
]
},
]
)
print(completion.choices[0].message)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: "gemini-2.0-flash"
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY
```
2. Start proxy
```bash
litellm --config config.yaml
```
3. Test it
```python
import base64
import requests
from openai import OpenAI
client = OpenAI(
base_url="http://0.0.0.0:4000",
api_key="sk-1234"
)
# Fetch the audio file and convert it to a base64 encoded string
url = "https://cdn.openai.com/API/docs/audio/alloy.wav"
response = requests.get(url)
response.raise_for_status()
wav_data = response.content
encoded_string = base64.b64encode(wav_data).decode('utf-8')
file = client.files.create(
file=wav_data,
purpose="user_data",
extra_body={"target_model_names": "gemini-2.0-flash"}
)
print(f"file: {file}")
assert file is not None
completion = client.chat.completions.create(
model="gemini-2.0-flash",
modalities=["text", "audio"],
audio={"voice": "alloy", "format": "wav"},
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this recording?"
},
{
"type": "file",
"file": {
"file_id": file.id,
"filename": "my-test-name",
"format": "audio/wav"
}
}
]
},
],
extra_body={"drop_params": True}
)
print(completion.choices[0].message)
```
</TabItem>
</Tabs>

View file

@ -2,466 +2,392 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# Huggingface # Hugging Face
LiteLLM supports running inference across multiple services for models hosted on the Hugging Face Hub.
LiteLLM supports the following types of Hugging Face models: - **Serverless Inference Providers** - Hugging Face offers an easy and unified access to serverless AI inference through multiple inference providers, like [Together AI](https://together.ai) and [Sambanova](https://sambanova.ai). This is the fastest way to integrate AI in your products with a maintenance-free and scalable solution. More details in the [Inference Providers documentation](https://huggingface.co/docs/inference-providers/index).
- **Dedicated Inference Endpoints** - which is a product to easily deploy models to production. Inference is run by Hugging Face in a dedicated, fully managed infrastructure on a cloud provider of your choice. You can deploy your model on Hugging Face Inference Endpoints by following [these steps](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint).
- Serverless Inference API (free) - loaded and ready to use: https://huggingface.co/models?inference=warm&pipeline_tag=text-generation
- Dedicated Inference Endpoints (paid) - manual deployment: https://ui.endpoints.huggingface.co/ ## Supported Models
- All LLMs served via Hugging Face's Inference use [Text-generation-inference](https://huggingface.co/docs/text-generation-inference).
### Serverless Inference Providers
You can check available models for an inference provider by going to [huggingface.co/models](https://huggingface.co/models), clicking the "Other" filter tab, and selecting your desired provider:
![Filter models by Inference Provider](../../img/hf_filter_inference_providers.png)
For example, you can find all Fireworks supported models [here](https://huggingface.co/models?inference_provider=fireworks-ai&sort=trending).
### Dedicated Inference Endpoints
Refer to the [Inference Endpoints catalog](https://endpoints.huggingface.co/catalog) for a list of available models.
## Usage ## Usage
<Tabs>
<TabItem value="serverless" label="Serverless Inference Providers">
### Authentication
With a single Hugging Face token, you can access inference through multiple providers. Your calls are routed through Hugging Face and the usage is billed directly to your Hugging Face account at the standard provider API rates.
Simply set the `HF_TOKEN` environment variable with your Hugging Face token, you can create one here: https://huggingface.co/settings/tokens.
```bash
export HF_TOKEN="hf_xxxxxx"
```
or alternatively, you can pass your Hugging Face token as a parameter:
```python
completion(..., api_key="hf_xxxxxx")
```
### Getting Started
To use a Hugging Face model, specify both the provider and model you want to use in the following format:
```
huggingface/<provider>/<hf_org_or_user>/<hf_model>
```
Where `<hf_org_or_user>/<hf_model>` is the Hugging Face model ID and `<provider>` is the inference provider.
By default, if you don't specify a provider, LiteLLM will use the [HF Inference API](https://huggingface.co/docs/api-inference/en/index).
Examples:
```python
# Run DeepSeek-R1 inference through Together AI
completion(model="huggingface/together/deepseek-ai/DeepSeek-R1",...)
# Run Qwen2.5-72B-Instruct inference through Sambanova
completion(model="huggingface/sambanova/Qwen/Qwen2.5-72B-Instruct",...)
# Run Llama-3.3-70B-Instruct inference through HF Inference API
completion(model="huggingface/meta-llama/Llama-3.3-70B-Instruct",...)
```
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb"> <a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a> </a>
You need to tell LiteLLM when you're calling Huggingface. ### Basic Completion
This is done by adding the "huggingface/" prefix to `model`, example `completion(model="huggingface/<model_name>",...)`. Here's an example of chat completion using the DeepSeek-R1 model through Together AI:
<Tabs>
<TabItem value="serverless" label="Serverless Inference API">
By default, LiteLLM will assume a Hugging Face call follows the [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api), which is fully compatible with the OpenAI Chat Completion API.
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
import os import os
from litellm import completion from litellm import completion
# [OPTIONAL] set env var os.environ["HF_TOKEN"] = "hf_xxxxxx"
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
response = completion( response = completion(
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", model="huggingface/together/deepseek-ai/DeepSeek-R1",
messages=[{ "content": "Hello, how are you?","role": "user"}], messages=[
{
"role": "user",
"content": "How many r's are in the word 'strawberry'?",
}
],
)
print(response)
```
### Streaming
Now, let's see what a streaming request looks like.
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
response = completion(
model="huggingface/together/deepseek-ai/DeepSeek-R1",
messages=[
{
"role": "user",
"content": "How many r's are in the word `strawberry`?",
}
],
stream=True,
)
for chunk in response:
print(chunk)
```
### Image Input
You can also pass images when the model supports it. Here is an example using [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) model through Sambanova.
```python
from litellm import completion
# Set your Hugging Face Token
os.environ["HF_TOKEN"] = "hf_xxxxxx"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
}
},
],
}
]
response = completion(
model="huggingface/sambanova/meta-llama/Llama-3.2-11B-Vision-Instruct",
messages=messages,
)
print(response.choices[0])
```
### Function Calling
You can extend the model's capabilities by giving them access to tools. Here is an example with function calling using [Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct) model through Sambanova.
```python
import os
from litellm import completion
# Set your Hugging Face Token
os.environ["HF_TOKEN"] = "hf_xxxxxx"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today?",
}
]
response = completion(
model="huggingface/sambanova/meta-llama/Llama-3.3-70B-Instruct",
messages=messages,
tools=tools,
tool_choice="auto"
)
print(response)
```
</TabItem>
<TabItem value="endpoints" label="Inference Endpoints">
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
### Basic Completion
After you have [deployed your Hugging Face Inference Endpoint](https://endpoints.huggingface.co/new) on dedicated infrastructure, you can run inference on it by providing the endpoint base URL in `api_base`, and indicating `huggingface/tgi` as the model name.
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
response = completion(
model="huggingface/tgi",
messages=[{"content": "Hello, how are you?", "role": "user"}],
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/"
)
print(response)
```
### Streaming
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
response = completion(
model="huggingface/tgi",
messages=[{"content": "Hello, how are you?", "role": "user"}],
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/",
stream=True stream=True
) )
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: llama-3.1-8B-instruct
litellm_params:
model: huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct
api_key: os.environ/HUGGINGFACE_API_KEY
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama-3.1-8B-instruct",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: bert-classifier
litellm_params:
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "bert-classifier",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="dedicated" label="Dedicated Inference Endpoints">
Steps to use
* Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/
* Set `api_base` to your deployed api base
* Add the `huggingface/` prefix to your model so litellm knows it's a huggingface Deployed Inference Endpoint
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
os.environ["HUGGINGFACE_API_KEY"] = ""
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
# add the 'huggingface/' prefix to the model to set huggingface as the provider
# set api base to your deployed api endpoint from hugging face
response = completion(
model="huggingface/glaiveai/glaive-coder-7b",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: glaive-coder
litellm_params:
model: huggingface/glaiveai/glaive-coder-7b
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "glaive-coder",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Streaming
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
You need to tell LiteLLM when you're calling Huggingface.
This is done by adding the "huggingface/" prefix to `model`, example `completion(model="huggingface/<model_name>",...)`.
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
model="huggingface/facebook/blenderbot-400M-distill",
messages=messages,
api_base="https://my-endpoint.huggingface.cloud",
stream=True
)
print(response)
for chunk in response: for chunk in response:
print(chunk) print(chunk)
``` ```
### Image Input
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
}
},
],
}
]
response = completion(
model="huggingface/tgi",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/""
)
print(response.choices[0])
```
### Function Calling
```python
import os
from litellm import completion
os.environ["HF_TOKEN"] = "hf_xxxxxx"
functions = [{
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get weather for"
}
},
"required": ["location"]
}
}]
response = completion(
model="huggingface/tgi",
messages=[{"content": "What's the weather like in San Francisco?", "role": "user"}],
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/",
functions=functions
)
print(response)
```
</TabItem>
</Tabs>
## LiteLLM Proxy Server with Hugging Face models
You can set up a [LiteLLM Proxy Server](https://docs.litellm.ai/#litellm-proxy-server-llm-gateway) to serve Hugging Face models through any of the supported Inference Providers. Here's how to do it:
### Step 1. Setup the config file
In this case, we are configuring a proxy to serve `DeepSeek R1` from Hugging Face, using Together AI as the backend Inference Provider.
```yaml
model_list:
- model_name: my-r1-model
litellm_params:
model: huggingface/together/deepseek-ai/DeepSeek-R1
api_key: os.environ/HF_TOKEN # ensure you have `HF_TOKEN` in your .env
```
### Step 2. Start the server
```bash
litellm --config /path/to/config.yaml
```
### Step 3. Make a request to the server
<Tabs>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
"model": "my-r1-model",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}'
```
</TabItem>
<TabItem value="python" label="python">
```python
# pip install openai
from openai import OpenAI
client = OpenAI(
base_url="http://0.0.0.0:4000",
api_key="anything",
)
response = client.chat.completions.create(
model="my-r1-model",
messages=[
{"role": "user", "content": "Hello, how are you?"}
]
)
print(response)
```
</TabItem>
</Tabs>
## Embedding ## Embedding
LiteLLM supports Hugging Face's [text-embedding-inference](https://github.com/huggingface/text-embeddings-inference) format. LiteLLM supports Hugging Face's [text-embedding-inference](https://github.com/huggingface/text-embeddings-inference) models as well.
```python ```python
from litellm import embedding from litellm import embedding
import os import os
os.environ['HUGGINGFACE_API_KEY'] = "" os.environ['HF_TOKEN'] = "hf_xxxxxx"
response = embedding( response = embedding(
model='huggingface/microsoft/codebert-base', model='huggingface/microsoft/codebert-base',
input=["good morning from litellm"] input=["good morning from litellm"]
) )
``` ```
## Advanced
### Setting API KEYS + API BASE
If required, you can set the api key + api base, set it in your os environment. [Code for how it's sent](https://github.com/BerriAI/litellm/blob/0100ab2382a0e720c7978fbf662cc6e6920e7e03/litellm/llms/huggingface_restapi.py#L25)
```python
import os
os.environ["HUGGINGFACE_API_KEY"] = ""
os.environ["HUGGINGFACE_API_BASE"] = ""
```
### Viewing Log probs
#### Using `decoder_input_details` - OpenAI `echo`
The `echo` param is supported by OpenAI Completions - Use `litellm.text_completion()` for this
```python
from litellm import text_completion
response = text_completion(
model="huggingface/bigcode/starcoder",
prompt="good morning",
max_tokens=10, logprobs=10,
echo=True
)
```
#### Output
```json
{
"id": "chatcmpl-3fc71792-c442-4ba1-a611-19dd0ac371ad",
"object": "text_completion",
"created": 1698801125.936519,
"model": "bigcode/starcoder",
"choices": [
{
"text": ", I'm going to make you a sand",
"index": 0,
"logprobs": {
"tokens": [
"good",
" morning",
",",
" I",
"'m",
" going",
" to",
" make",
" you",
" a",
" s",
"and"
],
"token_logprobs": [
"None",
-14.96875,
-2.2285156,
-2.734375,
-2.0957031,
-2.0917969,
-0.09429932,
-3.1132812,
-1.3203125,
-1.2304688,
-1.6201172,
-0.010292053
]
},
"finish_reason": "length"
}
],
"usage": {
"completion_tokens": 9,
"prompt_tokens": 2,
"total_tokens": 11
}
}
```
### Models with Prompt Formatting
For models with special prompt templates (e.g. Llama2), we format the prompt to fit their template.
#### Models with natively Supported Prompt Templates
| Model Name | Works for Models | Function Call | Required OS Variables |
| ------------------------------------ | ---------------------------------- | ----------------------------------------------------------------------------------------------------------------------- | ----------------------------------- |
| mistralai/Mistral-7B-Instruct-v0.1 | mistralai/Mistral-7B-Instruct-v0.1 | `completion(model='huggingface/mistralai/Mistral-7B-Instruct-v0.1', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| meta-llama/Llama-2-7b-chat | All meta-llama llama2 chat models | `completion(model='huggingface/meta-llama/Llama-2-7b', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| tiiuae/falcon-7b-instruct | All falcon instruct models | `completion(model='huggingface/tiiuae/falcon-7b-instruct', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| mosaicml/mpt-7b-chat | All mpt chat models | `completion(model='huggingface/mosaicml/mpt-7b-chat', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| codellama/CodeLlama-34b-Instruct-hf | All codellama instruct models | `completion(model='huggingface/codellama/CodeLlama-34b-Instruct-hf', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| WizardLM/WizardCoder-Python-34B-V1.0 | All wizardcoder models | `completion(model='huggingface/WizardLM/WizardCoder-Python-34B-V1.0', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
| Phind/Phind-CodeLlama-34B-v2 | All phind-codellama models | `completion(model='huggingface/Phind/Phind-CodeLlama-34B-v2', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
**What if we don't support a model you need?**
You can also specify you're own custom prompt formatting, in case we don't have your model covered yet.
**Does this mean you have to specify a prompt for all models?**
No. By default we'll concatenate your message content to make a prompt.
**Default Prompt Template**
```python
def default_pt(messages):
return " ".join(message["content"] for message in messages)
```
[Code for how prompt formats work in LiteLLM](https://github.com/BerriAI/litellm/blob/main/litellm/llms/prompt_templates/factory.py)
#### Custom prompt templates
```python
import litellm
# Create your own custom prompt template works
litellm.register_prompt_template(
model="togethercomputer/LLaMA-2-7B-32K",
roles={
"system": {
"pre_message": "[INST] <<SYS>>\n",
"post_message": "\n<</SYS>>\n [/INST]\n"
},
"user": {
"pre_message": "[INST] ",
"post_message": " [/INST]\n"
},
"assistant": {
"post_message": "\n"
}
}
)
def test_huggingface_custom_model():
model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud")
print(response['choices'][0]['message']['content'])
return response
test_huggingface_custom_model()
```
[Implementation Code](https://github.com/BerriAI/litellm/blob/c0b3da2c14c791a0b755f0b1e5a9ef065951ecbf/litellm/llms/huggingface_restapi.py#L52)
### Deploying a model on huggingface
You can use any chat/text model from Hugging Face with the following steps:
- Copy your model id/url from Huggingface Inference Endpoints
- [ ] Go to https://ui.endpoints.huggingface.co/
- [ ] Copy the url of the specific model you'd like to use
<Image img={require('../../img/hf_inference_endpoint.png')} alt="HF_Dashboard" style={{ maxWidth: '50%', height: 'auto' }}/>
- Set it as your model name
- Set your HUGGINGFACE_API_KEY as an environment variable
Need help deploying a model on huggingface? [Check out this guide.](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint)
# output
Same as the OpenAI format, but also includes logprobs. [See the code](https://github.com/BerriAI/litellm/blob/b4b2dbf005142e0a483d46a07a88a19814899403/litellm/llms/huggingface_restapi.py#L115)
```json
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "\ud83d\ude31\n\nComment: @SarahSzabo I'm",
"role": "assistant",
"logprobs": -22.697942825499993
}
}
],
"created": 1693436637.38206,
"model": "https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud",
"usage": {
"prompt_tokens": 14,
"completion_tokens": 11,
"total_tokens": 25
}
}
```
# FAQ # FAQ
**Does this support stop sequences?** **How does billing work with Hugging Face Inference Providers?**
Yes, we support stop sequences - and you can pass as many as allowed by Hugging Face (or any provider!) > Billing is centralized on your Hugging Face account, no matter which providers you are using. You are billed the standard provider API rates with no additional markup - Hugging Face simply passes through the provider costs. Note that [Hugging Face PRO](https://huggingface.co/subscribe/pro) users get $2 worth of Inference credits every month that can be used across providers.
**How do you deal with repetition penalty?** **Do I need to create an account for each Inference Provider?**
We map the presence penalty parameter in openai to the repetition penalty parameter on Hugging Face. [See code](https://github.com/BerriAI/litellm/blob/b4b2dbf005142e0a483d46a07a88a19814899403/litellm/utils.py#L757). > No, you don't need to create separate accounts. All requests are routed through Hugging Face, so you only need your HF token. This allows you to easily benchmark different providers and choose the one that best fits your needs.
**Will more inference providers be supported by Hugging Face in the future?**
> Yes! New inference providers (and models) are being added gradually.
We welcome any suggestions for improving our Hugging Face integration - Create an [issue](https://github.com/BerriAI/litellm/issues/new/choose)/[Join the Discord](https://discord.com/invite/wuPM9dRgDw)! We welcome any suggestions for improving our Hugging Face integration - Create an [issue](https://github.com/BerriAI/litellm/issues/new/choose)/[Join the Discord](https://discord.com/invite/wuPM9dRgDw)!

View file

@ -156,7 +156,7 @@ PROXY_LOGOUT_URL="https://www.google.com"
Set this in your .env (so the proxy can set the correct redirect url) Set this in your .env (so the proxy can set the correct redirect url)
```shell ```shell
PROXY_BASE_URL=https://litellm-api.up.railway.app/ PROXY_BASE_URL=https://litellm-api.up.railway.app
``` ```
#### Step 4. Test flow #### Step 4. Test flow

View file

@ -406,6 +406,7 @@ router_settings:
| HELICONE_API_KEY | API key for Helicone service | HELICONE_API_KEY | API key for Helicone service
| HOSTNAME | Hostname for the server, this will be [emitted to `datadog` logs](https://docs.litellm.ai/docs/proxy/logging#datadog) | HOSTNAME | Hostname for the server, this will be [emitted to `datadog` logs](https://docs.litellm.ai/docs/proxy/logging#datadog)
| HUGGINGFACE_API_BASE | Base URL for Hugging Face API | HUGGINGFACE_API_BASE | Base URL for Hugging Face API
| HUGGINGFACE_API_KEY | API key for Hugging Face API
| IAM_TOKEN_DB_AUTH | IAM token for database authentication | IAM_TOKEN_DB_AUTH | IAM token for database authentication
| JSON_LOGS | Enable JSON formatted logging | JSON_LOGS | Enable JSON formatted logging
| JWT_AUDIENCE | Expected audience for JWT tokens | JWT_AUDIENCE | Expected audience for JWT tokens

View file

@ -6,6 +6,8 @@ import Image from '@theme/IdealImage';
Track spend for keys, users, and teams across 100+ LLMs. Track spend for keys, users, and teams across 100+ LLMs.
LiteLLM automatically tracks spend for all known models. See our [model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
### How to Track Spend with LiteLLM ### How to Track Spend with LiteLLM
**Step 1** **Step 1**
@ -35,10 +37,10 @@ response = client.chat.completions.create(
"content": "this is a test request, write a short poem" "content": "this is a test request, write a short poem"
} }
], ],
user="palantir", user="palantir", # OPTIONAL: pass user to track spend by user
extra_body={ extra_body={
"metadata": { "metadata": {
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] "tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] # ENTERPRISE: pass tags to track spend by tags
} }
} }
) )
@ -63,9 +65,9 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
"content": "what llm are you" "content": "what llm are you"
} }
], ],
"user": "palantir", "user": "palantir", # OPTIONAL: pass user to track spend by user
"metadata": { "metadata": {
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] "tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] # ENTERPRISE: pass tags to track spend by tags
} }
}' }'
``` ```
@ -90,7 +92,7 @@ chat = ChatOpenAI(
user="palantir", user="palantir",
extra_body={ extra_body={
"metadata": { "metadata": {
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] "tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] # ENTERPRISE: pass tags to track spend by tags
} }
} }
) )
@ -150,8 +152,134 @@ Navigate to the Usage Tab on the LiteLLM UI (found on https://your-proxy-endpoin
</TabItem> </TabItem>
</Tabs> </Tabs>
## ✨ (Enterprise) API Endpoints to get Spend ### Allowing Non-Proxy Admins to access `/spend` endpoints
### Getting Spend Reports - To Charge Other Teams, Customers, Users
Use this when you want non-proxy admins to access `/spend` endpoints
:::info
Schedule a [meeting with us to get your Enterprise License](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
##### Create Key
Create Key with with `permissions={"get_spend_routes": true}`
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"permissions": {"get_spend_routes": true}
}'
```
##### Use generated key on `/spend` endpoints
Access spend Routes with newly generate keys
```shell
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30' \
-H 'Authorization: Bearer sk-H16BKvrSNConSsBYLGc_7A'
```
#### Reset Team, API Key Spend - MASTER KEY ONLY
Use `/global/spend/reset` if you want to:
- Reset the Spend for all API Keys, Teams. The `spend` for ALL Teams and Keys in `LiteLLM_TeamTable` and `LiteLLM_VerificationToken` will be set to `spend=0`
- LiteLLM will maintain all the logs in `LiteLLMSpendLogs` for Auditing Purposes
##### Request
Only the `LITELLM_MASTER_KEY` you set can access this route
```shell
curl -X POST \
'http://localhost:4000/global/spend/reset' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json'
```
##### Expected Responses
```shell
{"message":"Spend for all API Keys and Teams reset successfully","status":"success"}
```
## Set 'base_model' for Cost Tracking (e.g. Azure deployments)
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
**Solution** ✅ : Set `base_model` on your config so litellm uses the correct model for calculating azure cost
Get the base model name from [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
Example config with `base_model`
```yaml
model_list:
- model_name: azure-gpt-3.5
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview
```
## Daily Spend Breakdown API
Retrieve granular daily usage data for a user (by model, provider, and API key) with a single endpoint.
Example Request:
```shell title="Daily Spend Breakdown API" showLineNumbers
curl -L -X GET 'http://localhost:4000/user/daily/activity?start_date=2025-03-20&end_date=2025-03-27' \
-H 'Authorization: Bearer sk-...'
```
```json title="Daily Spend Breakdown API Response" showLineNumbers
{
"results": [
{
"date": "2025-03-27",
"metrics": {
"spend": 0.0177072,
"prompt_tokens": 111,
"completion_tokens": 1711,
"total_tokens": 1822,
"api_requests": 11
},
"breakdown": {
"models": {
"gpt-4o-mini": {
"spend": 1.095e-05,
"prompt_tokens": 37,
"completion_tokens": 9,
"total_tokens": 46,
"api_requests": 1
},
"providers": { "openai": { ... }, "azure_ai": { ... } },
"api_keys": { "3126b6eaf1...": { ... } }
}
}
],
"metadata": {
"total_spend": 0.7274667,
"total_prompt_tokens": 280990,
"total_completion_tokens": 376674,
"total_api_requests": 14
}
}
```
### API Reference
See our [Swagger API](https://litellm-api.up.railway.app/#/Budget%20%26%20Spend%20Tracking/get_user_daily_activity_user_daily_activity_get) for more details on the `/user/daily/activity` endpoint
## ✨ (Enterprise) Generate Spend Reports
Use this to charge other teams, customers, users
Use the `/global/spend/report` endpoint to get spend reports Use the `/global/spend/report` endpoint to get spend reports
@ -470,105 +598,6 @@ curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end
</Tabs> </Tabs>
### Allowing Non-Proxy Admins to access `/spend` endpoints
Use this when you want non-proxy admins to access `/spend` endpoints
:::info
Schedule a [meeting with us to get your Enterprise License](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
##### Create Key
Create Key with with `permissions={"get_spend_routes": true}`
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"permissions": {"get_spend_routes": true}
}'
```
##### Use generated key on `/spend` endpoints
Access spend Routes with newly generate keys
```shell
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30' \
-H 'Authorization: Bearer sk-H16BKvrSNConSsBYLGc_7A'
```
#### Reset Team, API Key Spend - MASTER KEY ONLY
Use `/global/spend/reset` if you want to:
- Reset the Spend for all API Keys, Teams. The `spend` for ALL Teams and Keys in `LiteLLM_TeamTable` and `LiteLLM_VerificationToken` will be set to `spend=0`
- LiteLLM will maintain all the logs in `LiteLLMSpendLogs` for Auditing Purposes
##### Request
Only the `LITELLM_MASTER_KEY` you set can access this route
```shell
curl -X POST \
'http://localhost:4000/global/spend/reset' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json'
```
##### Expected Responses
```shell
{"message":"Spend for all API Keys and Teams reset successfully","status":"success"}
```
## Spend Tracking for Azure OpenAI Models
Set base model for cost tracking azure image-gen call
#### Image Generation
```yaml
model_list:
- model_name: dall-e-3
litellm_params:
model: azure/dall-e-3-test
api_version: 2023-06-01-preview
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_key: os.environ/AZURE_API_KEY
base_model: dall-e-3 # 👈 set dall-e-3 as base model
model_info:
mode: image_generation
```
#### Chat Completions / Embeddings
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
**Solution** ✅ : Set `base_model` on your config so litellm uses the correct model for calculating azure cost
Get the base model name from [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
Example config with `base_model`
```yaml
model_list:
- model_name: azure-gpt-3.5
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview
```
## Custom Input/Output Pricing
👉 Head to [Custom Input/Output Pricing](https://docs.litellm.ai/docs/proxy/custom_pricing) to setup custom pricing or your models
## ✨ Custom Spend Log metadata ## ✨ Custom Spend Log metadata
@ -588,3 +617,4 @@ Logging specific key,value pairs in spend logs metadata is an enterprise feature
Tracking spend with Custom tags is an enterprise feature. [See here](./enterprise.md#tracking-spend-for-custom-tags) Tracking spend with Custom tags is an enterprise feature. [See here](./enterprise.md#tracking-spend-for-custom-tags)
::: :::

View file

@ -0,0 +1,86 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# High Availability Setup (Resolve DB Deadlocks)
Resolve any Database Deadlocks you see in high traffic by using this setup
## What causes the problem?
LiteLLM writes `UPDATE` and `UPSERT` queries to the DB. When using 10+ instances of LiteLLM, these queries can cause deadlocks since each instance could simultaneously attempt to update the same `user_id`, `team_id`, `key` etc.
## How the high availability setup fixes the problem
- All instances will write to a Redis queue instead of the DB.
- A single instance will acquire a lock on the DB and flush the redis queue to the DB.
## How it works
### Stage 1. Each instance writes updates to redis
Each instance will accumlate the spend updates for a key, user, team, etc and write the updates to a redis queue.
<Image img={require('../../img/deadlock_fix_1.png')} style={{ width: '900px', height: 'auto' }} />
<p style={{textAlign: 'left', color: '#666'}}>
Each instance writes updates to redis
</p>
### Stage 2. A single instance flushes the redis queue to the DB
A single instance will acquire a lock on the DB and flush all elements in the redis queue to the DB.
- 1 instance will attempt to acquire the lock for the DB update job
- The status of the lock is stored in redis
- If the instance acquires the lock to write to DB
- It will read all updates from redis
- Aggregate all updates into 1 transaction
- Write updates to DB
- Release the lock
- Note: Only 1 instance can acquire the lock at a time, this limits the number of instances that can write to the DB at once
<Image img={require('../../img/deadlock_fix_2.png')} style={{ width: '900px', height: 'auto' }} />
<p style={{textAlign: 'left', color: '#666'}}>
A single instance flushes the redis queue to the DB
</p>
## Usage
### Required components
- Redis
- Postgres
### Setup on LiteLLM config
You can enable using the redis buffer by setting `use_redis_transaction_buffer: true` in the `general_settings` section of your `proxy_config.yaml` file.
Note: This setup requires litellm to be connected to a redis instance.
```yaml showLineNumbers title="litellm proxy_config.yaml"
general_settings:
use_redis_transaction_buffer: true
litellm_settings:
cache: True
cache_params:
type: redis
supported_call_types: [] # Optional: Set cache for proxy, but not on the actual llm api call
```
## Monitoring
LiteLLM emits the following prometheus metrics to monitor the health/status of the in memory buffer and redis buffer.
| Metric Name | Description | Storage Type |
|-----------------------------------------------------|-----------------------------------------------------------------------------|--------------|
| `litellm_pod_lock_manager_size` | Indicates which pod has the lock to write updates to the database. | Redis |
| `litellm_in_memory_daily_spend_update_queue_size` | Number of items in the in-memory daily spend update queue. These are the aggregate spend logs for each user. | In-Memory |
| `litellm_redis_daily_spend_update_queue_size` | Number of items in the Redis daily spend update queue. These are the aggregate spend logs for each user. | Redis |
| `litellm_in_memory_spend_update_queue_size` | In-memory aggregate spend values for keys, users, teams, team members, etc.| In-Memory |
| `litellm_redis_spend_update_queue_size` | Redis aggregate spend values for keys, users, teams, etc. | Redis |

View file

@ -140,7 +140,7 @@ The above request should not be blocked, and you should receive a regular LLM re
</Tabs> </Tabs>
# Advanced ## Advanced
Aim Guard provides user-specific Guardrail policies, enabling you to apply tailored policies to individual users. Aim Guard provides user-specific Guardrail policies, enabling you to apply tailored policies to individual users.
To utilize this feature, include the end-user's email in the request payload by setting the `x-aim-user-email` header of your request. To utilize this feature, include the end-user's email in the request payload by setting the `x-aim-user-email` header of your request.

View file

@ -177,6 +177,50 @@ export LITELLM_SALT_KEY="sk-1234"
[**See Code**](https://github.com/BerriAI/litellm/blob/036a6821d588bd36d170713dcf5a72791a694178/litellm/proxy/common_utils/encrypt_decrypt_utils.py#L15) [**See Code**](https://github.com/BerriAI/litellm/blob/036a6821d588bd36d170713dcf5a72791a694178/litellm/proxy/common_utils/encrypt_decrypt_utils.py#L15)
## 9. Use `prisma migrate deploy`
Use this to handle db migrations across LiteLLM versions in production
<Tabs>
<TabItem value="env" label="ENV">
```bash
USE_PRISMA_MIGRATE="True"
```
</TabItem>
<TabItem value="cli" label="CLI">
```bash
litellm --use_prisma_migrate
```
</TabItem>
</Tabs>
Benefits:
The migrate deploy command:
- **Does not** issue a warning if an already applied migration is missing from migration history
- **Does not** detect drift (production database schema differs from migration history end state - for example, due to a hotfix)
- **Does not** reset the database or generate artifacts (such as Prisma Client)
- **Does not** rely on a shadow database
### How does LiteLLM handle DB migrations in production?
1. A new migration file is written to our `litellm-proxy-extras` package. [See all](https://github.com/BerriAI/litellm/tree/main/litellm-proxy-extras/litellm_proxy_extras/migrations)
2. The core litellm pip package is bumped to point to the new `litellm-proxy-extras` package. This ensures, older versions of LiteLLM will continue to use the old migrations. [See code](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/pyproject.toml#L58)
3. When you upgrade to a new version of LiteLLM, the migration file is applied to the database. [See code](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/litellm-proxy-extras/litellm_proxy_extras/utils.py#L42)
## Extras ## Extras
### Expected Performance in Production ### Expected Performance in Production

View file

@ -242,6 +242,19 @@ litellm_settings:
| `litellm_redis_fails` | Number of failed redis calls | | `litellm_redis_fails` | Number of failed redis calls |
| `litellm_self_latency` | Histogram latency for successful litellm api call | | `litellm_self_latency` | Histogram latency for successful litellm api call |
#### DB Transaction Queue Health Metrics
Use these metrics to monitor the health of the DB Transaction Queue. Eg. Monitoring the size of the in-memory and redis buffers.
| Metric Name | Description | Storage Type |
|-----------------------------------------------------|-----------------------------------------------------------------------------|--------------|
| `litellm_pod_lock_manager_size` | Indicates which pod has the lock to write updates to the database. | Redis |
| `litellm_in_memory_daily_spend_update_queue_size` | Number of items in the in-memory daily spend update queue. These are the aggregate spend logs for each user. | In-Memory |
| `litellm_redis_daily_spend_update_queue_size` | Number of items in the Redis daily spend update queue. These are the aggregate spend logs for each user. | Redis |
| `litellm_in_memory_spend_update_queue_size` | In-memory aggregate spend values for keys, users, teams, team members, etc.| In-Memory |
| `litellm_redis_spend_update_queue_size` | Redis aggregate spend values for keys, users, teams, etc. | Redis |
## **🔥 LiteLLM Maintained Grafana Dashboards ** ## **🔥 LiteLLM Maintained Grafana Dashboards **
@ -268,6 +281,17 @@ Here is a screenshot of the metrics you can monitor with the LiteLLM Grafana Das
## Add authentication on /metrics endpoint
**By default /metrics endpoint is unauthenticated.**
You can opt into running litellm authentication on the /metrics endpoint by setting the following on the config
```yaml
litellm_settings:
require_auth_for_metrics_endpoint: true
```
## FAQ ## FAQ
### What are `_created` vs. `_total` metrics? ### What are `_created` vs. `_total` metrics?

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 120 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 325 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 326 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 488 KiB

View file

@ -12559,9 +12559,10 @@
} }
}, },
"node_modules/image-size": { "node_modules/image-size": {
"version": "1.1.1", "version": "1.2.1",
"resolved": "https://registry.npmjs.org/image-size/-/image-size-1.1.1.tgz", "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz",
"integrity": "sha512-541xKlUw6jr/6gGuk92F+mYM5zaFAc5ahphvkqvNe2bQ6gVBkd6bfrmVJ2t4KDAfikAYZyIqTnktX3i6/aQDrQ==", "integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==",
"license": "MIT",
"dependencies": { "dependencies": {
"queue": "6.0.2" "queue": "6.0.2"
}, },

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -6,7 +6,7 @@ authors:
- name: Krrish Dholakia - name: Krrish Dholakia
title: CEO, LiteLLM title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/ url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer - name: Ishaan Jaffer
title: CTO, LiteLLM title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/ url: https://www.linkedin.com/in/reffajnaahsi/

View file

@ -0,0 +1,176 @@
---
title: v1.65.4-stable
slug: v1.65.4-stable
date: 2025-04-05T10:00:00
authors:
- name: Krrish Dholakia
title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
- name: Ishaan Jaffer
title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/
image_url: https://pbs.twimg.com/profile_images/1613813310264340481/lz54oEiB_400x400.jpg
tags: []
hide_table_of_contents: false
---
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
## Deploy this version
<Tabs>
<TabItem value="docker" label="Docker">
``` showLineNumbers title="docker run litellm"
docker run
-e STORE_MODEL_IN_DB=True
-p 4000:4000
ghcr.io/berriai/litellm:main-v1.65.4-stable
```
</TabItem>
<TabItem value="pip" label="Pip">
``` showLineNumbers title="pip install litellm"
pip install litellm==1.65.4.post1
```
</TabItem>
</Tabs>
v1.65.4-stable is live. Here are the improvements since v1.65.0-stable.
## Key Highlights
- **Preventing DB Deadlocks**: Fixes a high-traffic issue when multiple instances were writing to the DB at the same time.
- **New Usage Tab**: Enables viewing spend by model and customizing date range
Let's dive in.
### Preventing DB Deadlocks
<Image img={require('../../img/prevent_deadlocks.jpg')} />
This release fixes the DB deadlocking issue that users faced in high traffic (10K+ RPS). This is great because it enables user/key/team spend tracking works at that scale.
Read more about the new architecture [here](https://docs.litellm.ai/docs/proxy/db_deadlocks)
### New Usage Tab
<Image img={require('../../img/release_notes/spend_by_model.jpg')} />
The new Usage tab now brings the ability to track daily spend by model. This makes it easier to catch any spend tracking or token counting errors, when combined with the ability to view successful requests, and token usage.
To test this out, just go to Experimental > New Usage > Activity.
## New Models / Updated Models
1. Databricks - claude-3-7-sonnet cost tracking [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L10350)
2. VertexAI - `gemini-2.5-pro-exp-03-25` cost tracking [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L4492)
3. VertexAI - `gemini-2.0-flash` cost tracking [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L4689)
4. Groq - add whisper ASR models to model cost map [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L3324)
5. IBM - Add watsonx/ibm/granite-3-8b-instruct to model cost map [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L91)
6. Google AI Studio - add gemini/gemini-2.5-pro-preview-03-25 to model cost map [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L4850)
## LLM Translation
1. Vertex AI - Support anyOf param for OpenAI json schema translation [Get Started](https://docs.litellm.ai/docs/providers/vertex#json-schema)
2. Anthropic- response_format + thinking param support (works across Anthropic API, Bedrock, Vertex) [Get Started](https://docs.litellm.ai/docs/reasoning_content)
3. Anthropic - if thinking token is specified and max tokens is not - ensure max token to anthropic is higher than thinking tokens (works across Anthropic API, Bedrock, Vertex) [PR](https://github.com/BerriAI/litellm/pull/9594)
4. Bedrock - latency optimized inference support [Get Started](https://docs.litellm.ai/docs/providers/bedrock#usage---latency-optimized-inference)
5. Sagemaker - handle special tokens + multibyte character code in response [Get Started](https://docs.litellm.ai/docs/providers/aws_sagemaker)
6. MCP - add support for using SSE MCP servers [Get Started](https://docs.litellm.ai/docs/mcp#usage)
8. Anthropic - new `litellm.messages.create` interface for calling Anthropic `/v1/messages` via passthrough [Get Started](https://docs.litellm.ai/docs/anthropic_unified#usage)
11. Anthropic - support file content type in message param (works across Anthropic API, Bedrock, Vertex) [Get Started](https://docs.litellm.ai/docs/providers/anthropic#usage---pdf)
12. Anthropic - map openai 'reasoning_effort' to anthropic 'thinking' param (works across Anthropic API, Bedrock, Vertex) [Get Started](https://docs.litellm.ai/docs/providers/anthropic#usage---thinking--reasoning_content)
13. Google AI Studio (Gemini) - [BETA] `/v1/files` upload support [Get Started](../../docs/providers/google_ai_studio/files)
14. Azure - fix o-series tool calling [Get Started](../../docs/providers/azure#tool-calling--function-calling)
15. Unified file id - [ALPHA] allow calling multiple providers with same file id [PR](https://github.com/BerriAI/litellm/pull/9718)
- This is experimental, and not recommended for production use.
- We plan to have a production-ready implementation by next week.
16. Google AI Studio (Gemini) - return logprobs [PR](https://github.com/BerriAI/litellm/pull/9713)
17. Anthropic - Support prompt caching for Anthropic tool calls [Get Started](https://docs.litellm.ai/docs/completion/prompt_caching)
18. OpenRouter - unwrap extra body on open router calls [PR](https://github.com/BerriAI/litellm/pull/9747)
19. VertexAI - fix credential caching issue [PR](https://github.com/BerriAI/litellm/pull/9756)
20. XAI - filter out 'name' param for XAI [PR](https://github.com/BerriAI/litellm/pull/9761)
21. Gemini - image generation output support [Get Started](../../docs/providers/gemini#image-generation)
22. Databricks - support claude-3-7-sonnet w/ thinking + response_format [Get Started](../../docs/providers/databricks#usage---thinking--reasoning_content)
## Spend Tracking Improvements
1. Reliability fix - Check sent and received model for cost calculation [PR](https://github.com/BerriAI/litellm/pull/9669)
2. Vertex AI - Multimodal embedding cost tracking [Get Started](https://docs.litellm.ai/docs/providers/vertex#multi-modal-embeddings), [PR](https://github.com/BerriAI/litellm/pull/9623)
## Management Endpoints / UI
<Image img={require('../../img/release_notes/new_activity_tab.png')} />
1. New Usage Tab
- Report 'total_tokens' + report success/failure calls
- Remove double bars on scroll
- Ensure daily spend chart ordered from earliest to latest date
- showing spend per model per day
- show key alias on usage tab
- Allow non-admins to view their activity
- Add date picker to new usage tab
2. Virtual Keys Tab
- remove 'default key' on user signup
- fix showing user models available for personal key creation
3. Test Key Tab
- Allow testing image generation models
4. Models Tab
- Fix bulk adding models
- support reusable credentials for passthrough endpoints
- Allow team members to see team models
5. Teams Tab
- Fix json serialization error on update team metadata
6. Request Logs Tab
- Add reasoning_content token tracking across all providers on streaming
7. API
- return key alias on /user/daily/activity [Get Started](../../docs/proxy/cost_tracking#daily-spend-breakdown-api)
8. SSO
- Allow assigning SSO users to teams on MSFT SSO [PR](https://github.com/BerriAI/litellm/pull/9745)
## Logging / Guardrail Integrations
1. Console Logs - Add json formatting for uncaught exceptions [PR](https://github.com/BerriAI/litellm/pull/9619)
2. Guardrails - AIM Guardrails support for virtual key based policies [Get Started](../../docs/proxy/guardrails/aim_security)
3. Logging - fix completion start time tracking [PR](https://github.com/BerriAI/litellm/pull/9688)
4. Prometheus
- Allow adding authentication on Prometheus /metrics endpoints [PR](https://github.com/BerriAI/litellm/pull/9766)
- Distinguish LLM Provider Exception vs. LiteLLM Exception in metric naming [PR](https://github.com/BerriAI/litellm/pull/9760)
- Emit operational metrics for new DB Transaction architecture [PR](https://github.com/BerriAI/litellm/pull/9719)
## Performance / Loadbalancing / Reliability improvements
1. Preventing Deadlocks
- Reduce DB Deadlocks by storing spend updates in Redis and then committing to DB [PR](https://github.com/BerriAI/litellm/pull/9608)
- Ensure no deadlocks occur when updating DailyUserSpendTransaction [PR](https://github.com/BerriAI/litellm/pull/9690)
- High Traffic fix - ensure new DB + Redis architecture accurately tracks spend [PR](https://github.com/BerriAI/litellm/pull/9673)
- Use Redis for PodLock Manager instead of PG (ensures no deadlocks occur) [PR](https://github.com/BerriAI/litellm/pull/9715)
- v2 DB Deadlock Reduction Architecture Add Max Size for In-Memory Queue + Backpressure Mechanism [PR](https://github.com/BerriAI/litellm/pull/9759)
2. Prisma Migrations [Get Started](../../docs/proxy/prod#9-use-prisma-migrate-deploy)
- connects litellm proxy to litellm's prisma migration files
- Handle db schema updates from new `litellm-proxy-extras` sdk
3. Redis - support password for sync sentinel clients [PR](https://github.com/BerriAI/litellm/pull/9622)
4. Fix "Circular reference detected" error when max_parallel_requests = 0 [PR](https://github.com/BerriAI/litellm/pull/9671)
5. Code QA - Ban hardcoded numbers [PR](https://github.com/BerriAI/litellm/pull/9709)
## Helm
1. fix: wrong indentation of ttlSecondsAfterFinished in chart [PR](https://github.com/BerriAI/litellm/pull/9611)
## General Proxy Improvements
1. Fix - only apply service_account_settings.enforced_params on service accounts [PR](https://github.com/BerriAI/litellm/pull/9683)
2. Fix - handle metadata null on `/chat/completion` [PR](https://github.com/BerriAI/litellm/issues/9717)
3. Fix - Move daily user transaction logging outside of 'disable_spend_logs' flag, as theyre unrelated [PR](https://github.com/BerriAI/litellm/pull/9772)
## Demo
Try this on the demo instance [today](https://docs.litellm.ai/docs/proxy/demo)
## Complete Git Diff
See the complete git diff since v1.65.0-stable, [here](https://github.com/BerriAI/litellm/releases/tag/v1.65.4-stable)

View file

@ -53,7 +53,7 @@ const sidebars = {
{ {
type: "category", type: "category",
label: "Architecture", label: "Architecture",
items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch", "proxy/image_handling"], items: ["proxy/architecture", "proxy/db_info", "proxy/db_deadlocks", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch", "proxy/image_handling"],
}, },
{ {
type: "link", type: "link",
@ -188,7 +188,15 @@ const sidebars = {
"providers/azure_ai", "providers/azure_ai",
"providers/aiml", "providers/aiml",
"providers/vertex", "providers/vertex",
"providers/gemini",
{
type: "category",
label: "Google AI Studio",
items: [
"providers/gemini",
"providers/google_ai_studio/files",
]
},
"providers/anthropic", "providers/anthropic",
"providers/aws_sagemaker", "providers/aws_sagemaker",
"providers/bedrock", "providers/bedrock",

Binary file not shown.

View file

@ -0,0 +1,356 @@
datasource client {
provider = "postgresql"
url = env("DATABASE_URL")
}
generator client {
provider = "prisma-client-py"
}
// Budget / Rate Limits for an org
model LiteLLM_BudgetTable {
budget_id String @id @default(uuid())
max_budget Float?
soft_budget Float?
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
model_max_budget Json?
budget_duration String?
budget_reset_at DateTime?
created_at DateTime @default(now()) @map("created_at")
created_by String
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
updated_by String
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
}
// Models on proxy
model LiteLLM_CredentialsTable {
credential_id String @id @default(uuid())
credential_name String @unique
credential_values Json
credential_info Json?
created_at DateTime @default(now()) @map("created_at")
created_by String
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
updated_by String
}
// Models on proxy
model LiteLLM_ProxyModelTable {
model_id String @id @default(uuid())
model_name String
litellm_params Json
model_info Json?
created_at DateTime @default(now()) @map("created_at")
created_by String
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
updated_by String
}
model LiteLLM_OrganizationTable {
organization_id String @id @default(uuid())
organization_alias String
budget_id String
metadata Json @default("{}")
models String[]
spend Float @default(0.0)
model_spend Json @default("{}")
created_at DateTime @default(now()) @map("created_at")
created_by String
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
updated_by String
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
teams LiteLLM_TeamTable[]
users LiteLLM_UserTable[]
keys LiteLLM_VerificationToken[]
members LiteLLM_OrganizationMembership[] @relation("OrganizationToMembership")
}
// Model info for teams, just has model aliases for now.
model LiteLLM_ModelTable {
id Int @id @default(autoincrement())
model_aliases Json? @map("aliases")
created_at DateTime @default(now()) @map("created_at")
created_by String
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
updated_by String
team LiteLLM_TeamTable?
}
// Assign prod keys to groups, not individuals
model LiteLLM_TeamTable {
team_id String @id @default(uuid())
team_alias String?
organization_id String?
admins String[]
members String[]
members_with_roles Json @default("{}")
metadata Json @default("{}")
max_budget Float?
spend Float @default(0.0)
models String[]
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
budget_duration String?
budget_reset_at DateTime?
blocked Boolean @default(false)
created_at DateTime @default(now()) @map("created_at")
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
model_spend Json @default("{}")
model_max_budget Json @default("{}")
model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
}
// Track spend, rate limit, budget Users
model LiteLLM_UserTable {
user_id String @id
user_alias String?
team_id String?
sso_user_id String? @unique
organization_id String?
password String?
teams String[] @default([])
user_role String?
max_budget Float?
spend Float @default(0.0)
user_email String?
models String[]
metadata Json @default("{}")
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
budget_duration String?
budget_reset_at DateTime?
allowed_cache_controls String[] @default([])
model_spend Json @default("{}")
model_max_budget Json @default("{}")
created_at DateTime? @default(now()) @map("created_at")
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
// relations
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
organization_memberships LiteLLM_OrganizationMembership[]
invitations_created LiteLLM_InvitationLink[] @relation("CreatedBy")
invitations_updated LiteLLM_InvitationLink[] @relation("UpdatedBy")
invitations_user LiteLLM_InvitationLink[] @relation("UserId")
}
// Generate Tokens for Proxy
model LiteLLM_VerificationToken {
token String @id
key_name String?
key_alias String?
soft_budget_cooldown Boolean @default(false) // key-level state on if budget alerts need to be cooled down
spend Float @default(0.0)
expires DateTime?
models String[]
aliases Json @default("{}")
config Json @default("{}")
user_id String?
team_id String?
permissions Json @default("{}")
max_parallel_requests Int?
metadata Json @default("{}")
blocked Boolean?
tpm_limit BigInt?
rpm_limit BigInt?
max_budget Float?
budget_duration String?
budget_reset_at DateTime?
allowed_cache_controls String[] @default([])
model_spend Json @default("{}")
model_max_budget Json @default("{}")
budget_id String?
organization_id String?
created_at DateTime? @default(now()) @map("created_at")
created_by String?
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
updated_by String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
}
model LiteLLM_EndUserTable {
user_id String @id
alias String? // admin-facing alias
spend Float @default(0.0)
allowed_model_region String? // require all user requests to use models in this specific region
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
budget_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
blocked Boolean @default(false)
}
// store proxy config.yaml
model LiteLLM_Config {
param_name String @id
param_value Json?
}
// View spend, model, api_key per request
model LiteLLM_SpendLogs {
request_id String @id
call_type String
api_key String @default ("") // Hashed API Token. Not the actual Virtual Key. Equivalent to 'token' column in LiteLLM_VerificationToken
spend Float @default(0.0)
total_tokens Int @default(0)
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field
completionStartTime DateTime? // Assuming completionStartTime is a DateTime field
model String @default("")
model_id String? @default("") // the model id stored in proxy model db
model_group String? @default("") // public model_name / model_group
custom_llm_provider String? @default("") // litellm used custom_llm_provider
api_base String? @default("")
user String? @default("")
metadata Json? @default("{}")
cache_hit String? @default("")
cache_key String? @default("")
request_tags Json? @default("[]")
team_id String?
end_user String?
requester_ip_address String?
messages Json? @default("{}")
response Json? @default("{}")
@@index([startTime])
@@index([end_user])
}
// View spend, model, api_key per request
model LiteLLM_ErrorLogs {
request_id String @id @default(uuid())
startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field
api_base String @default("")
model_group String @default("") // public model_name / model_group
litellm_model_name String @default("") // model passed to litellm
model_id String @default("") // ID of model in ProxyModelTable
request_kwargs Json @default("{}")
exception_type String @default("")
exception_string String @default("")
status_code String @default("")
}
// Beta - allow team members to request access to a model
model LiteLLM_UserNotifications {
request_id String @id
user_id String
models String[]
justification String
status String // approved, disapproved, pending
}
model LiteLLM_TeamMembership {
// Use this table to track the Internal User's Spend within a Team + Set Budgets, rpm limits for the user within the team
user_id String
team_id String
spend Float @default(0.0)
budget_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
@@id([user_id, team_id])
}
model LiteLLM_OrganizationMembership {
// Use this table to track Internal User and Organization membership. Helps tracking a users role within an Organization
user_id String
organization_id String
user_role String?
spend Float? @default(0.0)
budget_id String?
created_at DateTime? @default(now()) @map("created_at")
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
// relations
user LiteLLM_UserTable @relation(fields: [user_id], references: [user_id])
organization LiteLLM_OrganizationTable @relation("OrganizationToMembership", fields: [organization_id], references: [organization_id])
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
@@id([user_id, organization_id])
@@unique([user_id, organization_id])
}
model LiteLLM_InvitationLink {
// use this table to track invite links sent by admin for people to join the proxy
id String @id @default(uuid())
user_id String
is_accepted Boolean @default(false)
accepted_at DateTime? // when link is claimed (user successfully onboards via link)
expires_at DateTime // till when is link valid
created_at DateTime // when did admin create the link
created_by String // who created the link
updated_at DateTime // when was invite status updated
updated_by String // who updated the status (admin/user who accepted invite)
// Relations
liteLLM_user_table_user LiteLLM_UserTable @relation("UserId", fields: [user_id], references: [user_id])
liteLLM_user_table_created LiteLLM_UserTable @relation("CreatedBy", fields: [created_by], references: [user_id])
liteLLM_user_table_updated LiteLLM_UserTable @relation("UpdatedBy", fields: [updated_by], references: [user_id])
}
model LiteLLM_AuditLog {
id String @id @default(uuid())
updated_at DateTime @default(now())
changed_by String @default("") // user or system that performed the action
changed_by_api_key String @default("") // api key hash that performed the action
action String // create, update, delete
table_name String // on of LitellmTableNames.TEAM_TABLE_NAME, LitellmTableNames.USER_TABLE_NAME, LitellmTableNames.PROXY_MODEL_TABLE_NAME,
object_id String // id of the object being audited. This can be the key id, team id, user id, model id
before_value Json? // value of the row
updated_values Json? // value of the row after change
}
// Track daily user spend metrics per model and key
model LiteLLM_DailyUserSpend {
id String @id @default(uuid())
user_id String
date String
api_key String
model String
model_group String?
custom_llm_provider String?
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
spend Float @default(0.0)
api_requests Int @default(0)
successful_requests Int @default(0)
failed_requests Int @default(0)
created_at DateTime @default(now())
updated_at DateTime @updatedAt
@@unique([user_id, date, api_key, model, custom_llm_provider])
@@index([date])
@@index([user_id])
@@index([api_key])
@@index([model])
}
// Track the status of cron jobs running. Only allow one pod to run the job at a time
model LiteLLM_CronJob {
cronjob_id String @id @default(cuid()) // Unique ID for the record
pod_id String // Unique identifier for the pod acting as the leader
status JobStatus @default(INACTIVE) // Status of the cron job (active or inactive)
last_updated DateTime @default(now()) // Timestamp for the last update of the cron job record
ttl DateTime // Time when the leader's lease expires
}
enum JobStatus {
ACTIVE
INACTIVE
}

View file

@ -30,21 +30,23 @@ class ProxyExtrasDBManager:
use_migrate = str_to_bool(os.getenv("USE_PRISMA_MIGRATE")) or use_migrate use_migrate = str_to_bool(os.getenv("USE_PRISMA_MIGRATE")) or use_migrate
for attempt in range(4): for attempt in range(4):
original_dir = os.getcwd() original_dir = os.getcwd()
schema_dir = os.path.dirname(schema_path) migrations_dir = os.path.dirname(__file__)
os.chdir(schema_dir) os.chdir(migrations_dir)
try: try:
if use_migrate: if use_migrate:
logger.info("Running prisma migrate deploy") logger.info("Running prisma migrate deploy")
try: try:
# Set migrations directory for Prisma # Set migrations directory for Prisma
subprocess.run( result = subprocess.run(
["prisma", "migrate", "deploy"], ["prisma", "migrate", "deploy"],
timeout=60, timeout=60,
check=True, check=True,
capture_output=True, capture_output=True,
text=True, text=True,
) )
logger.info(f"prisma migrate deploy stdout: {result.stdout}")
logger.info("prisma migrate deploy completed") logger.info("prisma migrate deploy completed")
return True return True
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
@ -77,4 +79,5 @@ class ProxyExtrasDBManager:
time.sleep(random.randrange(5, 15)) time.sleep(random.randrange(5, 15))
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
pass
return False return False

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm-proxy-extras" name = "litellm-proxy-extras"
version = "0.1.2" version = "0.1.3"
description = "Additional files for the LiteLLM Proxy. Reduces the size of the main litellm package." description = "Additional files for the LiteLLM Proxy. Reduces the size of the main litellm package."
authors = ["BerriAI"] authors = ["BerriAI"]
readme = "README.md" readme = "README.md"
@ -22,7 +22,7 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "0.1.2" version = "0.1.3"
version_files = [ version_files = [
"pyproject.toml:version", "pyproject.toml:version",
"../requirements.txt:litellm-proxy-extras==", "../requirements.txt:litellm-proxy-extras==",

View file

@ -56,6 +56,9 @@ from litellm.constants import (
bedrock_embedding_models, bedrock_embedding_models,
known_tokenizer_config, known_tokenizer_config,
BEDROCK_INVOKE_PROVIDERS_LITERAL, BEDROCK_INVOKE_PROVIDERS_LITERAL,
DEFAULT_MAX_TOKENS,
DEFAULT_SOFT_BUDGET,
DEFAULT_ALLOWED_FAILS,
) )
from litellm.types.guardrails import GuardrailItem from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import ( from litellm.proxy._types import (
@ -120,6 +123,7 @@ callbacks: List[
langfuse_default_tags: Optional[List[str]] = None langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None langsmith_batch_size: Optional[int] = None
prometheus_initialize_budget_metrics: Optional[bool] = False prometheus_initialize_budget_metrics: Optional[bool] = False
require_auth_for_metrics_endpoint: Optional[bool] = False
argilla_batch_size: Optional[int] = None argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[ gcs_pub_sub_use_v1: Optional[
@ -155,7 +159,7 @@ token: Optional[
str str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 ] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
telemetry = True telemetry = True
max_tokens = 256 # OpenAI Defaults max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
modify_params = False modify_params = False
retry = True retry = True
@ -244,7 +248,7 @@ budget_duration: Optional[
str str
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). ] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
default_soft_budget: float = ( default_soft_budget: float = (
50.0 # by default all litellm proxy keys have a soft budget of 50.0 DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
) )
forward_traceparent_to_llm_provider: bool = False forward_traceparent_to_llm_provider: bool = False
@ -796,9 +800,8 @@ from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
from .llms.galadriel.chat.transformation import GaladrielChatConfig from .llms.galadriel.chat.transformation import GaladrielChatConfig
from .llms.github.chat.transformation import GithubChatConfig from .llms.github.chat.transformation import GithubChatConfig
from .llms.empower.chat.transformation import EmpowerChatConfig from .llms.empower.chat.transformation import EmpowerChatConfig
from .llms.huggingface.chat.transformation import ( from .llms.huggingface.chat.transformation import HuggingFaceChatConfig
HuggingfaceChatConfig as HuggingfaceConfig, from .llms.huggingface.embedding.transformation import HuggingFaceEmbeddingConfig
)
from .llms.oobabooga.chat.transformation import OobaboogaConfig from .llms.oobabooga.chat.transformation import OobaboogaConfig
from .llms.maritalk import MaritalkConfig from .llms.maritalk import MaritalkConfig
from .llms.openrouter.chat.transformation import OpenrouterConfig from .llms.openrouter.chat.transformation import OpenrouterConfig

View file

@ -18,6 +18,7 @@ import redis # type: ignore
import redis.asyncio as async_redis # type: ignore import redis.asyncio as async_redis # type: ignore
from litellm import get_secret, get_secret_str from litellm import get_secret, get_secret_str
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
from ._logging import verbose_logger from ._logging import verbose_logger
@ -215,7 +216,7 @@ def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
# Set up the Sentinel client # Set up the Sentinel client
sentinel = redis.Sentinel( sentinel = redis.Sentinel(
sentinel_nodes, sentinel_nodes,
socket_timeout=0.1, socket_timeout=REDIS_SOCKET_TIMEOUT,
password=sentinel_password, password=sentinel_password,
) )
@ -239,7 +240,7 @@ def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
# Set up the Sentinel client # Set up the Sentinel client
sentinel = async_redis.Sentinel( sentinel = async_redis.Sentinel(
sentinel_nodes, sentinel_nodes,
socket_timeout=0.1, socket_timeout=REDIS_SOCKET_TIMEOUT,
password=sentinel_password, password=sentinel_password,
) )
@ -319,7 +320,7 @@ def get_redis_connection_pool(**env_overrides):
verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs) verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs)
if "url" in redis_kwargs and redis_kwargs["url"] is not None: if "url" in redis_kwargs and redis_kwargs["url"] is not None:
return async_redis.BlockingConnectionPool.from_url( return async_redis.BlockingConnectionPool.from_url(
timeout=5, url=redis_kwargs["url"] timeout=REDIS_CONNECTION_POOL_TIMEOUT, url=redis_kwargs["url"]
) )
connection_class = async_redis.Connection connection_class = async_redis.Connection
if "ssl" in redis_kwargs: if "ssl" in redis_kwargs:
@ -327,4 +328,6 @@ def get_redis_connection_pool(**env_overrides):
redis_kwargs.pop("ssl", None) redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class redis_kwargs["connection_class"] = connection_class
redis_kwargs.pop("startup_nodes", None) redis_kwargs.pop("startup_nodes", None)
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs) return async_redis.BlockingConnectionPool(
timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
)

View file

@ -124,6 +124,7 @@ class ServiceLogging(CustomLogger):
service=service, service=service,
duration=duration, duration=duration,
call_type=call_type, call_type=call_type,
event_metadata=event_metadata,
) )
for callback in litellm.service_callback: for callback in litellm.service_callback:
@ -229,6 +230,7 @@ class ServiceLogging(CustomLogger):
service=service, service=service,
duration=duration, duration=duration,
call_type=call_type, call_type=call_type,
event_metadata=event_metadata,
) )
for callback in litellm.service_callback: for callback in litellm.service_callback:

View file

@ -14,6 +14,12 @@ import time
from typing import Literal, Optional from typing import Literal, Optional
import litellm import litellm
from litellm.constants import (
DAYS_IN_A_MONTH,
DAYS_IN_A_WEEK,
DAYS_IN_A_YEAR,
HOURS_IN_A_DAY,
)
from litellm.utils import ModelResponse from litellm.utils import ModelResponse
@ -81,11 +87,11 @@ class BudgetManager:
if duration == "daily": if duration == "daily":
duration_in_days = 1 duration_in_days = 1
elif duration == "weekly": elif duration == "weekly":
duration_in_days = 7 duration_in_days = DAYS_IN_A_WEEK
elif duration == "monthly": elif duration == "monthly":
duration_in_days = 28 duration_in_days = DAYS_IN_A_MONTH
elif duration == "yearly": elif duration == "yearly":
duration_in_days = 365 duration_in_days = DAYS_IN_A_YEAR
else: else:
raise ValueError( raise ValueError(
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""" """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
@ -182,7 +188,9 @@ class BudgetManager:
current_time = time.time() current_time = time.time()
# Convert duration from days to seconds # Convert duration from days to seconds
duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60 duration_in_seconds = (
self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
)
# Check if duration has elapsed # Check if duration has elapsed
if current_time - last_updated_at >= duration_in_seconds: if current_time - last_updated_at >= duration_in_seconds:

View file

@ -19,6 +19,7 @@ from pydantic import BaseModel
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import * from litellm.types.caching import *
from litellm.types.utils import all_litellm_params from litellm.types.utils import all_litellm_params
@ -406,7 +407,7 @@ class Cache:
} }
] ]
} }
time.sleep(0.02) time.sleep(CACHED_STREAMING_CHUNK_DELAY)
def _get_cache_logic( def _get_cache_logic(
self, self,

View file

@ -15,7 +15,8 @@ from typing import Any, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from ..constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
from .base_cache import BaseCache from .base_cache import BaseCache
@ -52,7 +53,8 @@ class InMemoryCache(BaseCache):
# Fast path for common primitive types that are typically small # Fast path for common primitive types that are typically small
if ( if (
isinstance(value, (bool, int, float, str)) isinstance(value, (bool, int, float, str))
and len(str(value)) < self.max_size_per_item * 512 and len(str(value))
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
): # Conservative estimate ): # Conservative estimate
return True return True

View file

@ -11,10 +11,12 @@ Has 4 methods:
import ast import ast
import asyncio import asyncio
import json import json
from typing import Any from typing import Any, cast
import litellm import litellm
from litellm._logging import print_verbose from litellm._logging import print_verbose
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache from .base_cache import BaseCache
@ -118,7 +120,11 @@ class QdrantSemanticCache(BaseCache):
} }
elif quantization_config == "scalar": elif quantization_config == "scalar":
quantization_params = { quantization_params = {
"scalar": {"type": "int8", "quantile": 0.99, "always_ram": False} "scalar": {
"type": "int8",
"quantile": QDRANT_SCALAR_QUANTILE,
"always_ram": False,
}
} }
elif quantization_config == "product": elif quantization_config == "product":
quantization_params = { quantization_params = {
@ -132,7 +138,7 @@ class QdrantSemanticCache(BaseCache):
new_collection_status = self.sync_client.put( new_collection_status = self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}", url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
json={ json={
"vectors": {"size": 1536, "distance": "Cosine"}, "vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
"quantization_config": quantization_params, "quantization_config": quantization_params,
}, },
headers=self.headers, headers=self.headers,
@ -171,10 +177,13 @@ class QdrantSemanticCache(BaseCache):
prompt += message["content"] prompt += message["content"]
# create an embedding for prompt # create an embedding for prompt
embedding_response = litellm.embedding( embedding_response = cast(
model=self.embedding_model, EmbeddingResponse,
input=prompt, litellm.embedding(
cache={"no-store": True, "no-cache": True}, model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
) )
# get the embedding # get the embedding
@ -212,10 +221,13 @@ class QdrantSemanticCache(BaseCache):
prompt += message["content"] prompt += message["content"]
# convert to embedding # convert to embedding
embedding_response = litellm.embedding( embedding_response = cast(
model=self.embedding_model, EmbeddingResponse,
input=prompt, litellm.embedding(
cache={"no-store": True, "no-cache": True}, model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
) )
# get the embedding # get the embedding

View file

@ -304,12 +304,18 @@ class RedisCache(BaseCache):
key = self.check_and_fix_namespace(key=key) key = self.check_and_fix_namespace(key=key)
ttl = self.get_ttl(**kwargs) ttl = self.get_ttl(**kwargs)
nx = kwargs.get("nx", False)
print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}") print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
try: try:
if not hasattr(_redis_client, "set"): if not hasattr(_redis_client, "set"):
raise Exception("Redis client cannot set cache. Attribute not found.") raise Exception("Redis client cannot set cache. Attribute not found.")
await _redis_client.set(name=key, value=json.dumps(value), ex=ttl) result = await _redis_client.set(
name=key,
value=json.dumps(value),
nx=nx,
ex=ttl,
)
print_verbose( print_verbose(
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
) )
@ -326,6 +332,7 @@ class RedisCache(BaseCache):
event_metadata={"key": key}, event_metadata={"key": key},
) )
) )
return result
except Exception as e: except Exception as e:
end_time = time.time() end_time = time.time()
_duration = end_time - start_time _duration = end_time - start_time
@ -931,7 +938,7 @@ class RedisCache(BaseCache):
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete` # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
_redis_client: Any = self.init_async_client() _redis_client: Any = self.init_async_client()
# keys is str # keys is str
await _redis_client.delete(key) return await _redis_client.delete(key)
def delete_cache(self, key): def delete_cache(self, key):
self.redis_client.delete(key) self.redis_client.delete(key)

View file

@ -9,6 +9,7 @@ DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute 0.5 # default cooldown a deployment if 50% of requests fail in a given minute
) )
DEFAULT_MAX_TOKENS = 4096 DEFAULT_MAX_TOKENS = 4096
DEFAULT_ALLOWED_FAILS = 3
DEFAULT_REDIS_SYNC_INTERVAL = 1 DEFAULT_REDIS_SYNC_INTERVAL = 1
DEFAULT_COOLDOWN_TIME_SECONDS = 5 DEFAULT_COOLDOWN_TIME_SECONDS = 5
DEFAULT_REPLICATE_POLLING_RETRIES = 5 DEFAULT_REPLICATE_POLLING_RETRIES = 5
@ -16,16 +17,76 @@ DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
DEFAULT_IMAGE_TOKEN_COUNT = 250 DEFAULT_IMAGE_TOKEN_COUNT = 250
DEFAULT_IMAGE_WIDTH = 300 DEFAULT_IMAGE_WIDTH = 300
DEFAULT_IMAGE_HEIGHT = 300 DEFAULT_IMAGE_HEIGHT = 300
DEFAULT_MAX_TOKENS = 256 # used when providers need a default
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic. SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
########### v2 Architecture constants for managing writing updates to the database ###########
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer" REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer" REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100 MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
MAX_SIZE_IN_MEMORY_QUEUE = 10000
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000
###############################################################################################
MINIMUM_PROMPT_CACHE_TOKEN_COUNT = (
1024 # minimum number of tokens to cache a prompt by Anthropic
)
DEFAULT_TRIM_RATIO = 0.75 # default ratio of tokens to trim from the end of a prompt
HOURS_IN_A_DAY = 24
DAYS_IN_A_WEEK = 7
DAYS_IN_A_MONTH = 28
DAYS_IN_A_YEAR = 365
REPLICATE_MODEL_NAME_WITH_ID_LENGTH = 64
#### TOKEN COUNTING ####
FUNCTION_DEFINITION_TOKEN_COUNT = 9
SYSTEM_MESSAGE_TOKEN_COUNT = 4
TOOL_CHOICE_OBJECT_TOKEN_COUNT = 4
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT = 10
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT = 20
MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES = 768
MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES = 2000
MAX_TILE_WIDTH = 512
MAX_TILE_HEIGHT = 512
OPENAI_FILE_SEARCH_COST_PER_1K_CALLS = 2.5 / 1000
MIN_NON_ZERO_TEMPERATURE = 0.0001
#### RELIABILITY #### #### RELIABILITY ####
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives. REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
DEFAULT_MAX_LRU_CACHE_SIZE = 16
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
JITTER = 0.75
DEFAULT_IN_MEMORY_TTL = 5 # default time to live for the in-memory cache
DEFAULT_POLLING_INTERVAL = 0.03 # default polling interval for the scheduler
AZURE_OPERATION_POLLING_TIMEOUT = 120
REDIS_SOCKET_TIMEOUT = 0.1
REDIS_CONNECTION_POOL_TIMEOUT = 5
NON_LLM_CONNECTION_TIMEOUT = 15 # timeout for adjacent services (e.g. jwt auth)
MAX_EXCEPTION_MESSAGE_LENGTH = 2000
BEDROCK_MAX_POLICY_SIZE = 75
REPLICATE_POLLING_DELAY_SECONDS = 0.5
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS = 4096
TOGETHER_AI_4_B = 4
TOGETHER_AI_8_B = 8
TOGETHER_AI_21_B = 21
TOGETHER_AI_41_B = 41
TOGETHER_AI_80_B = 80
TOGETHER_AI_110_B = 110
TOGETHER_AI_EMBEDDING_150_M = 150
TOGETHER_AI_EMBEDDING_350_M = 350
QDRANT_SCALAR_QUANTILE = 0.99
QDRANT_VECTOR_SIZE = 1536
CACHED_STREAMING_CHUNK_DELAY = 0.02
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 512
DEFAULT_MAX_TOKENS_FOR_TRITON = 2000
#### Networking settings #### #### Networking settings ####
request_timeout: float = 6000 # time in seconds request_timeout: float = 6000 # time in seconds
STREAM_SSE_DONE_STRING: str = "[DONE]" STREAM_SSE_DONE_STRING: str = "[DONE]"
### SPEND TRACKING ###
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND = 0.001400 # price per second for a100 80GB
FIREWORKS_AI_56_B_MOE = 56
FIREWORKS_AI_176_B_MOE = 176
FIREWORKS_AI_16_B = 16
FIREWORKS_AI_80_B = 80
LITELLM_CHAT_PROVIDERS = [ LITELLM_CHAT_PROVIDERS = [
"openai", "openai",
@ -426,6 +487,9 @@ MCP_TOOL_NAME_PREFIX = "mcp_tool"
MAX_SPENDLOG_ROWS_TO_QUERY = ( MAX_SPENDLOG_ROWS_TO_QUERY = (
1_000_000 # if spendLogs has more than 1M rows, do not query the DB 1_000_000 # if spendLogs has more than 1M rows, do not query the DB
) )
DEFAULT_SOFT_BUDGET = (
50.0 # by default all litellm proxy keys have a soft budget of 50.0
)
# makes it clear this is a rate limit error for a litellm virtual key # makes it clear this is a rate limit error for a litellm virtual key
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash" RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
@ -451,3 +515,14 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
########################### DB CRON JOB NAMES ########################### ########################### DB CRON JOB NAMES ###########################
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job" DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
PROXY_BATCH_WRITE_AT = 10 # in seconds
DEFAULT_HEALTH_CHECK_INTERVAL = 300 # 5 minutes
PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = 9
DEFAULT_MODEL_CREATED_AT_TIME = 1677610602 # returns on `/models` endpoint
DEFAULT_SLACK_ALERTING_THRESHOLD = 300
MAX_TEAM_LIST_LIMIT = 20
DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = 0.7
LENGTH_OF_LITELLM_GENERATED_KEY = 16
SECRET_MANAGER_REFRESH_INTERVAL = 86400

View file

@ -9,6 +9,10 @@ from pydantic import BaseModel
import litellm import litellm
import litellm._logging import litellm._logging
from litellm import verbose_logger from litellm import verbose_logger
from litellm.constants import (
DEFAULT_MAX_LRU_CACHE_SIZE,
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND,
)
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking, StandardBuiltInToolCostTracking,
) )
@ -355,9 +359,7 @@ def cost_per_token( # noqa: PLR0915
def get_replicate_completion_pricing(completion_response: dict, total_time=0.0): def get_replicate_completion_pricing(completion_response: dict, total_time=0.0):
# see https://replicate.com/pricing # see https://replicate.com/pricing
# for all litellm currently supported LLMs, almost all requests go to a100_80gb # for all litellm currently supported LLMs, almost all requests go to a100_80gb
a100_80gb_price_per_second_public = ( a100_80gb_price_per_second_public = DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND # assume all calls sent to A100 80GB for now
0.001400 # assume all calls sent to A100 80GB for now
)
if total_time == 0.0: # total time is in ms if total_time == 0.0: # total time is in ms
start_time = completion_response.get("created", time.time()) start_time = completion_response.get("created", time.time())
end_time = getattr(completion_response, "ended", time.time()) end_time = getattr(completion_response, "ended", time.time())
@ -450,7 +452,7 @@ def _select_model_name_for_cost_calc(
return return_model return return_model
@lru_cache(maxsize=16) @lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE)
def _model_contains_known_llm_provider(model: str) -> bool: def _model_contains_known_llm_provider(model: str) -> bool:
""" """
Check if the model contains a known llm provider Check if the model contains a known llm provider

View file

@ -63,16 +63,17 @@ async def acreate_file(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True kwargs["acreate_file"] = True
# Use a partial function to pass your keyword arguments call_args = {
func = partial( "file": file,
create_file, "purpose": purpose,
file, "custom_llm_provider": custom_llm_provider,
purpose, "extra_headers": extra_headers,
custom_llm_provider, "extra_body": extra_body,
extra_headers,
extra_body,
**kwargs, **kwargs,
) }
# Use a partial function to pass your keyword arguments
func = partial(create_file, **call_args)
# Add the context to the function # Add the context to the function
ctx = contextvars.copy_context() ctx = contextvars.copy_context()
@ -92,7 +93,7 @@ async def acreate_file(
def create_file( def create_file(
file: FileTypes, file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"], purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
@ -101,6 +102,8 @@ def create_file(
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API. Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
Specify either provider_list or custom_llm_provider.
""" """
try: try:
_is_async = kwargs.pop("acreate_file", False) is True _is_async = kwargs.pop("acreate_file", False) is True
@ -120,7 +123,7 @@ def create_file(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout

View file

@ -16,6 +16,7 @@ import litellm.litellm_core_utils.litellm_logging
import litellm.types import litellm.types
from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.constants import HOURS_IN_A_DAY
from litellm.integrations.custom_batch_logger import CustomBatchLogger from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.litellm_core_utils.exception_mapping_utils import ( from litellm.litellm_core_utils.exception_mapping_utils import (
@ -649,10 +650,10 @@ class SlackAlerting(CustomBatchLogger):
event_message += ( event_message += (
f"Budget Crossed\n Total Budget:`{user_info.max_budget}`" f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
) )
elif percent_left <= 0.05: elif percent_left <= SLACK_ALERTING_THRESHOLD_5_PERCENT:
event = "threshold_crossed" event = "threshold_crossed"
event_message += "5% Threshold Crossed " event_message += "5% Threshold Crossed "
elif percent_left <= 0.15: elif percent_left <= SLACK_ALERTING_THRESHOLD_15_PERCENT:
event = "threshold_crossed" event = "threshold_crossed"
event_message += "15% Threshold Crossed" event_message += "15% Threshold Crossed"
elif user_info.soft_budget is not None: elif user_info.soft_budget is not None:
@ -1718,7 +1719,7 @@ Model Info:
await self.internal_usage_cache.async_set_cache( await self.internal_usage_cache.async_set_cache(
key=_event_cache_key, key=_event_cache_key,
value="SENT", value="SENT",
ttl=(30 * 24 * 60 * 60), # 1 month ttl=(30 * HOURS_IN_A_DAY * 60 * 60), # 1 month
) )
except Exception as e: except Exception as e:

View file

@ -41,7 +41,7 @@ from litellm.types.utils import StandardLoggingPayload
from ..additional_logging_utils import AdditionalLoggingUtils from ..additional_logging_utils import AdditionalLoggingUtils
# max number of logs DD API can accept # max number of logs DD API can accept
DD_MAX_BATCH_SIZE = 1000
# specify what ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with large number of service types) # specify what ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with large number of service types)
DD_LOGGED_SUCCESS_SERVICE_TYPES = [ DD_LOGGED_SUCCESS_SERVICE_TYPES = [

View file

@ -20,10 +20,6 @@ else:
VertexBase = Any VertexBase = Any
GCS_DEFAULT_BATCH_SIZE = 2048
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
def __init__(self, bucket_name: Optional[str] = None) -> None: def __init__(self, bucket_name: Optional[str] = None) -> None:
from litellm.proxy.proxy_server import premium_user from litellm.proxy.proxy_server import premium_user
@ -125,6 +121,7 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs kwargs
) )
headers = await self.construct_request_headers( headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"], vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"], service_account_json=gcs_logging_config["path_service_account"],

View file

@ -818,7 +818,7 @@ class PrometheusLogger(CustomLogger):
requested_model=request_data.get("model", ""), requested_model=request_data.get("model", ""),
status_code=str(getattr(original_exception, "status_code", None)), status_code=str(getattr(original_exception, "status_code", None)),
exception_status=str(getattr(original_exception, "status_code", None)), exception_status=str(getattr(original_exception, "status_code", None)),
exception_class=str(original_exception.__class__.__name__), exception_class=self._get_exception_class_name(original_exception),
tags=_tags, tags=_tags,
) )
_labels = prometheus_label_factory( _labels = prometheus_label_factory(
@ -917,7 +917,7 @@ class PrometheusLogger(CustomLogger):
api_base=api_base, api_base=api_base,
api_provider=llm_provider, api_provider=llm_provider,
exception_status=str(getattr(exception, "status_code", None)), exception_status=str(getattr(exception, "status_code", None)),
exception_class=exception.__class__.__name__, exception_class=self._get_exception_class_name(exception),
requested_model=model_group, requested_model=model_group,
hashed_api_key=standard_logging_payload["metadata"][ hashed_api_key=standard_logging_payload["metadata"][
"user_api_key_hash" "user_api_key_hash"
@ -1146,6 +1146,22 @@ class PrometheusLogger(CustomLogger):
) )
return return
@staticmethod
def _get_exception_class_name(exception: Exception) -> str:
exception_class_name = ""
if hasattr(exception, "llm_provider"):
exception_class_name = getattr(exception, "llm_provider") or ""
# pretty print the provider name on prometheus
# eg. `openai` -> `Openai.`
if len(exception_class_name) >= 1:
exception_class_name = (
exception_class_name[0].upper() + exception_class_name[1:] + "."
)
exception_class_name += exception.__class__.__name__
return exception_class_name
async def log_success_fallback_event( async def log_success_fallback_event(
self, original_model_group: str, kwargs: dict, original_exception: Exception self, original_model_group: str, kwargs: dict, original_exception: Exception
): ):
@ -1181,7 +1197,7 @@ class PrometheusLogger(CustomLogger):
team=standard_metadata["user_api_key_team_id"], team=standard_metadata["user_api_key_team_id"],
team_alias=standard_metadata["user_api_key_team_alias"], team_alias=standard_metadata["user_api_key_team_alias"],
exception_status=str(getattr(original_exception, "status_code", None)), exception_status=str(getattr(original_exception, "status_code", None)),
exception_class=str(original_exception.__class__.__name__), exception_class=self._get_exception_class_name(original_exception),
tags=_tags, tags=_tags,
) )
_labels = prometheus_label_factory( _labels = prometheus_label_factory(
@ -1225,7 +1241,7 @@ class PrometheusLogger(CustomLogger):
team=standard_metadata["user_api_key_team_id"], team=standard_metadata["user_api_key_team_id"],
team_alias=standard_metadata["user_api_key_team_alias"], team_alias=standard_metadata["user_api_key_team_alias"],
exception_status=str(getattr(original_exception, "status_code", None)), exception_status=str(getattr(original_exception, "status_code", None)),
exception_class=str(original_exception.__class__.__name__), exception_class=self._get_exception_class_name(original_exception),
tags=_tags, tags=_tags,
) )
@ -1721,6 +1737,36 @@ class PrometheusLogger(CustomLogger):
return (end_time - start_time).total_seconds() return (end_time - start_time).total_seconds()
return None return None
@staticmethod
def _mount_metrics_endpoint(premium_user: bool):
"""
Mount the Prometheus metrics endpoint with optional authentication.
Args:
premium_user (bool): Whether the user is a premium user
require_auth (bool, optional): Whether to require authentication for the metrics endpoint.
Defaults to False.
"""
from prometheus_client import make_asgi_app
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import app
if premium_user is not True:
verbose_proxy_logger.warning(
f"Prometheus metrics are only available for premium users. {CommonProxyErrors.not_premium_user.value}"
)
# Create metrics ASGI app
metrics_app = make_asgi_app()
# Mount the metrics app to the app
app.mount("/metrics", metrics_app)
verbose_proxy_logger.debug(
"Starting Prometheus Metrics on /metrics (no authentication)"
)
def prometheus_label_factory( def prometheus_label_factory(
supported_enum_labels: List[str], supported_enum_labels: List[str],

View file

@ -3,11 +3,16 @@
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers) # On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
from typing import List, Optional, Union from typing import Dict, List, Optional, Union
from litellm._logging import print_verbose, verbose_logger from litellm._logging import print_verbose, verbose_logger
from litellm.types.integrations.prometheus import LATENCY_BUCKETS from litellm.types.integrations.prometheus import LATENCY_BUCKETS
from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.services import (
DEFAULT_SERVICE_CONFIGS,
ServiceLoggerPayload,
ServiceMetrics,
ServiceTypes,
)
FAILED_REQUESTS_LABELS = ["error_class", "function_name"] FAILED_REQUESTS_LABELS = ["error_class", "function_name"]
@ -23,7 +28,8 @@ class PrometheusServicesLogger:
): ):
try: try:
try: try:
from prometheus_client import REGISTRY, Counter, Histogram from prometheus_client import REGISTRY, Counter, Gauge, Histogram
from prometheus_client.gc_collector import Collector
except ImportError: except ImportError:
raise Exception( raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`" "Missing prometheus_client. Run `pip install prometheus-client`"
@ -31,36 +37,51 @@ class PrometheusServicesLogger:
self.Histogram = Histogram self.Histogram = Histogram
self.Counter = Counter self.Counter = Counter
self.Gauge = Gauge
self.REGISTRY = REGISTRY self.REGISTRY = REGISTRY
verbose_logger.debug("in init prometheus services metrics") verbose_logger.debug("in init prometheus services metrics")
self.services = [item.value for item in ServiceTypes] self.payload_to_prometheus_map: Dict[
str, List[Union[Histogram, Counter, Gauge, Collector]]
] = {}
self.payload_to_prometheus_map = ( for service in ServiceTypes:
{} service_metrics: List[Union[Histogram, Counter, Gauge, Collector]] = []
) # store the prometheus histogram/counter we need to call for each field in payload
for service in self.services: metrics_to_initialize = self._get_service_metrics_initialize(service)
histogram = self.create_histogram(service, type_of_request="latency")
counter_failed_request = self.create_counter(
service,
type_of_request="failed_requests",
additional_labels=FAILED_REQUESTS_LABELS,
)
counter_total_requests = self.create_counter(
service, type_of_request="total_requests"
)
self.payload_to_prometheus_map[service] = [
histogram,
counter_failed_request,
counter_total_requests,
]
self.prometheus_to_amount_map: dict = ( # Initialize only the configured metrics for each service
{} if ServiceMetrics.HISTOGRAM in metrics_to_initialize:
) # the field / value in ServiceLoggerPayload the object needs to be incremented by histogram = self.create_histogram(
service.value, type_of_request="latency"
)
if histogram:
service_metrics.append(histogram)
if ServiceMetrics.COUNTER in metrics_to_initialize:
counter_failed_request = self.create_counter(
service.value,
type_of_request="failed_requests",
additional_labels=FAILED_REQUESTS_LABELS,
)
if counter_failed_request:
service_metrics.append(counter_failed_request)
counter_total_requests = self.create_counter(
service.value, type_of_request="total_requests"
)
if counter_total_requests:
service_metrics.append(counter_total_requests)
if ServiceMetrics.GAUGE in metrics_to_initialize:
gauge = self.create_gauge(service.value, type_of_request="size")
if gauge:
service_metrics.append(gauge)
if service_metrics:
self.payload_to_prometheus_map[service.value] = service_metrics
self.prometheus_to_amount_map: dict = {}
### MOCK TESTING ### ### MOCK TESTING ###
self.mock_testing = mock_testing self.mock_testing = mock_testing
self.mock_testing_success_calls = 0 self.mock_testing_success_calls = 0
@ -70,6 +91,19 @@ class PrometheusServicesLogger:
print_verbose(f"Got exception on init prometheus client {str(e)}") print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e raise e
def _get_service_metrics_initialize(
self, service: ServiceTypes
) -> List[ServiceMetrics]:
DEFAULT_METRICS = [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
if service not in DEFAULT_SERVICE_CONFIGS:
return DEFAULT_METRICS
metrics = DEFAULT_SERVICE_CONFIGS.get(service, {}).get("metrics", [])
if not metrics:
verbose_logger.debug(f"No metrics found for service {service}")
return DEFAULT_METRICS
return metrics
def is_metric_registered(self, metric_name) -> bool: def is_metric_registered(self, metric_name) -> bool:
for metric in self.REGISTRY.collect(): for metric in self.REGISTRY.collect():
if metric_name == metric.name: if metric_name == metric.name:
@ -94,6 +128,15 @@ class PrometheusServicesLogger:
buckets=LATENCY_BUCKETS, buckets=LATENCY_BUCKETS,
) )
def create_gauge(self, service: str, type_of_request: str):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self._get_metric(metric_name)
return self.Gauge(
metric_name, "Gauge for {} service".format(service), labelnames=[service]
)
def create_counter( def create_counter(
self, self,
service: str, service: str,
@ -120,6 +163,15 @@ class PrometheusServicesLogger:
histogram.labels(labels).observe(amount) histogram.labels(labels).observe(amount)
def update_gauge(
self,
gauge,
labels: str,
amount: float,
):
assert isinstance(gauge, self.Gauge)
gauge.labels(labels).set(amount)
def increment_counter( def increment_counter(
self, self,
counter, counter,
@ -190,6 +242,13 @@ class PrometheusServicesLogger:
labels=payload.service.value, labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
) )
elif isinstance(obj, self.Gauge):
if payload.event_metadata:
self.update_gauge(
gauge=obj,
labels=payload.event_metadata.get("gauge_labels") or "",
amount=payload.event_metadata.get("gauge_value") or 0,
)
async def async_service_failure_hook( async def async_service_failure_hook(
self, self,

View file

@ -10,6 +10,7 @@ class CredentialAccessor:
@staticmethod @staticmethod
def get_credential_values(credential_name: str) -> dict: def get_credential_values(credential_name: str) -> dict:
"""Safe accessor for credentials.""" """Safe accessor for credentials."""
if not litellm.credential_list: if not litellm.credential_list:
return {} return {}
for credential in litellm.credential_list: for credential in litellm.credential_list:

View file

@ -3,6 +3,7 @@ from typing import Optional, Tuple
import httpx import httpx
import litellm import litellm
from litellm.constants import REPLICATE_MODEL_NAME_WITH_ID_LENGTH
from litellm.secret_managers.main import get_secret, get_secret_str from litellm.secret_managers.main import get_secret, get_secret_str
from ..types.router import LiteLLM_Params from ..types.router import LiteLLM_Params
@ -256,10 +257,13 @@ def get_llm_provider( # noqa: PLR0915
elif model in litellm.cohere_chat_models: elif model in litellm.cohere_chat_models:
custom_llm_provider = "cohere_chat" custom_llm_provider = "cohere_chat"
## replicate ## replicate
elif model in litellm.replicate_models or (":" in model and len(model) > 64): elif model in litellm.replicate_models or (
":" in model and len(model) > REPLICATE_MODEL_NAME_WITH_ID_LENGTH
):
model_parts = model.split(":") model_parts = model.split(":")
if ( if (
len(model_parts) > 1 and len(model_parts[1]) == 64 len(model_parts) > 1
and len(model_parts[1]) == REPLICATE_MODEL_NAME_WITH_ID_LENGTH
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" ): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
custom_llm_provider = "replicate" custom_llm_provider = "replicate"
elif model in litellm.replicate_models: elif model in litellm.replicate_models:

View file

@ -120,7 +120,7 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "replicate": elif custom_llm_provider == "replicate":
return litellm.ReplicateConfig().get_supported_openai_params(model=model) return litellm.ReplicateConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface":
return litellm.HuggingfaceConfig().get_supported_openai_params(model=model) return litellm.HuggingFaceChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "jina_ai": elif custom_llm_provider == "jina_ai":
if request_type == "embeddings": if request_type == "embeddings":
return litellm.JinaAIEmbeddingConfig().get_supported_openai_params() return litellm.JinaAIEmbeddingConfig().get_supported_openai_params()

View file

@ -28,6 +28,10 @@ from litellm._logging import _is_debugging_on, verbose_logger
from litellm.batches.batch_utils import _handle_completed_batch from litellm.batches.batch_utils import _handle_completed_batch
from litellm.caching.caching import DualCache, InMemoryCache from litellm.caching.caching import DualCache, InMemoryCache
from litellm.caching.caching_handler import LLMCachingHandler from litellm.caching.caching_handler import LLMCachingHandler
from litellm.constants import (
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
)
from litellm.cost_calculator import _select_model_name_for_cost_calc from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.integrations.arize.arize import ArizeLogger from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import CustomGuardrail
@ -453,8 +457,12 @@ class Logging(LiteLLMLoggingBaseClass):
non_default_params: dict, non_default_params: dict,
prompt_id: str, prompt_id: str,
prompt_variables: Optional[dict], prompt_variables: Optional[dict],
prompt_management_logger: Optional[CustomLogger] = None,
) -> Tuple[str, List[AllMessageValues], dict]: ) -> Tuple[str, List[AllMessageValues], dict]:
custom_logger = self.get_custom_logger_for_prompt_management(model) custom_logger = (
prompt_management_logger
or self.get_custom_logger_for_prompt_management(model)
)
if custom_logger: if custom_logger:
( (
model, model,
@ -3745,9 +3753,12 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
response_cost=response_cost, response_cost=response_cost,
response_cost_failure_debug_info=None, response_cost_failure_debug_info=None,
status=str("success"), status=str("success"),
total_tokens=int(30), total_tokens=int(
prompt_tokens=int(20), DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT
completion_tokens=int(10), + DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT
),
prompt_tokens=int(DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT),
completion_tokens=int(DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT),
startTime=start_time, startTime=start_time,
endTime=end_time, endTime=end_time,
completionStartTime=completion_start_time, completionStartTime=completion_start_time,

View file

@ -5,6 +5,7 @@ Helper utilities for tracking the cost of built-in tools.
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import litellm import litellm
from litellm.constants import OPENAI_FILE_SEARCH_COST_PER_1K_CALLS
from litellm.types.llms.openai import FileSearchTool, WebSearchOptions from litellm.types.llms.openai import FileSearchTool, WebSearchOptions
from litellm.types.utils import ( from litellm.types.utils import (
ModelInfo, ModelInfo,
@ -132,7 +133,7 @@ class StandardBuiltInToolCostTracking:
""" """
if file_search is None: if file_search is None:
return 0.0 return 0.0
return 2.5 / 1000 return OPENAI_FILE_SEARCH_COST_PER_1K_CALLS
@staticmethod @staticmethod
def chat_completion_response_includes_annotations( def chat_completion_response_includes_annotations(

View file

@ -9,6 +9,7 @@ from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.types.llms.databricks import DatabricksTool
from litellm.types.llms.openai import ChatCompletionThinkingBlock from litellm.types.llms.openai import ChatCompletionThinkingBlock
from litellm.types.utils import ( from litellm.types.utils import (
ChatCompletionDeltaToolCall, ChatCompletionDeltaToolCall,
@ -35,6 +36,25 @@ from litellm.types.utils import (
from .get_headers import get_response_headers from .get_headers import get_response_headers
def convert_tool_call_to_json_mode(
tool_calls: List[ChatCompletionMessageToolCall],
convert_tool_call_to_json_mode: bool,
) -> Tuple[Optional[Message], Optional[str]]:
if _should_convert_tool_call_to_json_mode(
tool_calls=tool_calls,
convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
):
# to support 'json_schema' logic on older models
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
if json_mode_content_str is not None:
message = litellm.Message(content=json_mode_content_str)
finish_reason = "stop"
return message, finish_reason
return None, None
async def convert_to_streaming_response_async(response_object: Optional[dict] = None): async def convert_to_streaming_response_async(response_object: Optional[dict] = None):
""" """
Asynchronously converts a response object to a streaming response. Asynchronously converts a response object to a streaming response.
@ -335,21 +355,14 @@ class LiteLLMResponseObjectHandler:
Only supported for HF TGI models Only supported for HF TGI models
""" """
transformed_logprobs: Optional[TextCompletionLogprobs] = None transformed_logprobs: Optional[TextCompletionLogprobs] = None
if custom_llm_provider == "huggingface":
# only supported for TGI models
try:
raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.huggingface._transform_logprobs(
hf_response=raw_response
)
except Exception as e:
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
return transformed_logprobs return transformed_logprobs
def _should_convert_tool_call_to_json_mode( def _should_convert_tool_call_to_json_mode(
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None, tool_calls: Optional[
Union[List[ChatCompletionMessageToolCall], List[DatabricksTool]]
] = None,
convert_tool_call_to_json_mode: Optional[bool] = None, convert_tool_call_to_json_mode: Optional[bool] = None,
) -> bool: ) -> bool:
""" """

View file

@ -7,6 +7,7 @@ from typing import Dict, List, Literal, Optional, Union, cast
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
AllMessageValues, AllMessageValues,
ChatCompletionAssistantMessage, ChatCompletionAssistantMessage,
ChatCompletionFileObject,
ChatCompletionUserMessage, ChatCompletionUserMessage,
) )
from litellm.types.utils import Choices, ModelResponse, StreamingChoices from litellm.types.utils import Choices, ModelResponse, StreamingChoices
@ -34,7 +35,7 @@ def handle_messages_with_content_list_to_str_conversion(
def strip_name_from_messages( def strip_name_from_messages(
messages: List[AllMessageValues], messages: List[AllMessageValues], allowed_name_roles: List[str] = ["user"]
) -> List[AllMessageValues]: ) -> List[AllMessageValues]:
""" """
Removes 'name' from messages Removes 'name' from messages
@ -43,7 +44,7 @@ def strip_name_from_messages(
for message in messages: for message in messages:
msg_role = message.get("role") msg_role = message.get("role")
msg_copy = message.copy() msg_copy = message.copy()
if msg_role == "user": if msg_role not in allowed_name_roles:
msg_copy.pop("name", None) # type: ignore msg_copy.pop("name", None) # type: ignore
new_messages.append(msg_copy) new_messages.append(msg_copy)
return new_messages return new_messages
@ -292,3 +293,58 @@ def get_completion_messages(
messages, assistant_continue_message, ensure_alternating_roles messages, assistant_continue_message, ensure_alternating_roles
) )
return messages return messages
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
"""
Gets file ids from messages
"""
file_ids = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
file_ids.append(file_id)
return file_ids
def update_messages_with_model_file_ids(
messages: List[AllMessageValues],
model_id: str,
model_file_id_mapping: Dict[str, Dict[str, str]],
) -> List[AllMessageValues]:
"""
Updates messages with model file ids.
model_file_id_mapping: Dict[str, Dict[str, str]] = {
"litellm_proxy/file_id": {
"model_id": "provider_file_id"
}
}
"""
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
provider_file_id = (
model_file_id_mapping.get(file_id, {}).get(model_id)
or file_id
)
file_object_file_field["file_id"] = provider_file_id
return messages

View file

@ -1300,20 +1300,37 @@ def convert_to_anthropic_tool_invoke(
] ]
} }
""" """
anthropic_tool_invoke = [ anthropic_tool_invoke = []
AnthropicMessagesToolUseParam(
for tool in tool_calls:
if not get_attribute_or_key(tool, "type") == "function":
continue
_anthropic_tool_use_param = AnthropicMessagesToolUseParam(
type="tool_use", type="tool_use",
id=get_attribute_or_key(tool, "id"), id=cast(str, get_attribute_or_key(tool, "id")),
name=get_attribute_or_key(get_attribute_or_key(tool, "function"), "name"), name=cast(
str,
get_attribute_or_key(get_attribute_or_key(tool, "function"), "name"),
),
input=json.loads( input=json.loads(
get_attribute_or_key( get_attribute_or_key(
get_attribute_or_key(tool, "function"), "arguments" get_attribute_or_key(tool, "function"), "arguments"
) )
), ),
) )
for tool in tool_calls
if get_attribute_or_key(tool, "type") == "function" _content_element = add_cache_control_to_content(
] anthropic_content_element=_anthropic_tool_use_param,
orignal_content_element=dict(tool),
)
if "cache_control" in _content_element:
_anthropic_tool_use_param["cache_control"] = _content_element[
"cache_control"
]
anthropic_tool_invoke.append(_anthropic_tool_use_param)
return anthropic_tool_invoke return anthropic_tool_invoke
@ -1324,6 +1341,7 @@ def add_cache_control_to_content(
AnthropicMessagesImageParam, AnthropicMessagesImageParam,
AnthropicMessagesTextParam, AnthropicMessagesTextParam,
AnthropicMessagesDocumentParam, AnthropicMessagesDocumentParam,
AnthropicMessagesToolUseParam,
ChatCompletionThinkingBlock, ChatCompletionThinkingBlock,
], ],
orignal_content_element: Union[dict, AllMessageValues], orignal_content_element: Union[dict, AllMessageValues],

View file

@ -1,6 +1,6 @@
import base64 import base64
import time import time
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union, cast
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
ChatCompletionAssistantContentValue, ChatCompletionAssistantContentValue,
@ -9,7 +9,9 @@ from litellm.types.llms.openai import (
from litellm.types.utils import ( from litellm.types.utils import (
ChatCompletionAudioResponse, ChatCompletionAudioResponse,
ChatCompletionMessageToolCall, ChatCompletionMessageToolCall,
Choices,
CompletionTokensDetails, CompletionTokensDetails,
CompletionTokensDetailsWrapper,
Function, Function,
FunctionCall, FunctionCall,
ModelResponse, ModelResponse,
@ -203,14 +205,14 @@ class ChunkProcessor:
) )
def get_combined_content( def get_combined_content(
self, chunks: List[Dict[str, Any]] self, chunks: List[Dict[str, Any]], delta_key: str = "content"
) -> ChatCompletionAssistantContentValue: ) -> ChatCompletionAssistantContentValue:
content_list: List[str] = [] content_list: List[str] = []
for chunk in chunks: for chunk in chunks:
choices = chunk["choices"] choices = chunk["choices"]
for choice in choices: for choice in choices:
delta = choice.get("delta", {}) delta = choice.get("delta", {})
content = delta.get("content", "") content = delta.get(delta_key, "")
if content is None: if content is None:
continue # openai v1.0.0 sets content = None for chunks continue # openai v1.0.0 sets content = None for chunks
content_list.append(content) content_list.append(content)
@ -221,6 +223,11 @@ class ChunkProcessor:
# Update the "content" field within the response dictionary # Update the "content" field within the response dictionary
return combined_content return combined_content
def get_combined_reasoning_content(
self, chunks: List[Dict[str, Any]]
) -> ChatCompletionAssistantContentValue:
return self.get_combined_content(chunks, delta_key="reasoning_content")
def get_combined_audio_content( def get_combined_audio_content(
self, chunks: List[Dict[str, Any]] self, chunks: List[Dict[str, Any]]
) -> ChatCompletionAudioResponse: ) -> ChatCompletionAudioResponse:
@ -296,12 +303,27 @@ class ChunkProcessor:
"prompt_tokens_details": prompt_tokens_details, "prompt_tokens_details": prompt_tokens_details,
} }
def count_reasoning_tokens(self, response: ModelResponse) -> int:
reasoning_tokens = 0
for choice in response.choices:
if (
hasattr(cast(Choices, choice).message, "reasoning_content")
and cast(Choices, choice).message.reasoning_content is not None
):
reasoning_tokens += token_counter(
text=cast(Choices, choice).message.reasoning_content,
count_response_tokens=True,
)
return reasoning_tokens
def calculate_usage( def calculate_usage(
self, self,
chunks: List[Union[Dict[str, Any], ModelResponse]], chunks: List[Union[Dict[str, Any], ModelResponse]],
model: str, model: str,
completion_output: str, completion_output: str,
messages: Optional[List] = None, messages: Optional[List] = None,
reasoning_tokens: Optional[int] = None,
) -> Usage: ) -> Usage:
""" """
Calculate usage for the given chunks. Calculate usage for the given chunks.
@ -382,6 +404,19 @@ class ChunkProcessor:
) # for anthropic ) # for anthropic
if completion_tokens_details is not None: if completion_tokens_details is not None:
returned_usage.completion_tokens_details = completion_tokens_details returned_usage.completion_tokens_details = completion_tokens_details
if reasoning_tokens is not None:
if returned_usage.completion_tokens_details is None:
returned_usage.completion_tokens_details = (
CompletionTokensDetailsWrapper(reasoning_tokens=reasoning_tokens)
)
elif (
returned_usage.completion_tokens_details is not None
and returned_usage.completion_tokens_details.reasoning_tokens is None
):
returned_usage.completion_tokens_details.reasoning_tokens = (
reasoning_tokens
)
if prompt_tokens_details is not None: if prompt_tokens_details is not None:
returned_usage.prompt_tokens_details = prompt_tokens_details returned_usage.prompt_tokens_details = prompt_tokens_details

View file

@ -214,10 +214,7 @@ class CustomStreamWrapper:
Output parse <s> / </s> special tokens for sagemaker + hf streaming. Output parse <s> / </s> special tokens for sagemaker + hf streaming.
""" """
hold = False hold = False
if ( if self.custom_llm_provider != "sagemaker":
self.custom_llm_provider != "huggingface"
and self.custom_llm_provider != "sagemaker"
):
return hold, chunk return hold, chunk
if finish_reason: if finish_reason:
@ -290,49 +287,6 @@ class CustomStreamWrapper:
except Exception as e: except Exception as e:
raise e raise e
def handle_huggingface_chunk(self, chunk):
try:
if not isinstance(chunk, str):
chunk = chunk.decode(
"utf-8"
) # DO NOT REMOVE this: This is required for HF inference API + Streaming
text = ""
is_finished = False
finish_reason = ""
print_verbose(f"chunk: {chunk}")
if chunk.startswith("data:"):
data_json = json.loads(chunk[5:])
print_verbose(f"data json: {data_json}")
if "token" in data_json and "text" in data_json["token"]:
text = data_json["token"]["text"]
if data_json.get("details", False) and data_json["details"].get(
"finish_reason", False
):
is_finished = True
finish_reason = data_json["details"]["finish_reason"]
elif data_json.get(
"generated_text", False
): # if full generated text exists, then stream is complete
text = "" # don't return the final bos token
is_finished = True
finish_reason = "stop"
elif data_json.get("error", False):
raise Exception(data_json.get("error"))
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif "error" in chunk:
raise ValueError(chunk)
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except Exception as e:
raise e
def handle_ai21_chunk(self, chunk): # fake streaming def handle_ai21_chunk(self, chunk): # fake streaming
chunk = chunk.decode("utf-8") chunk = chunk.decode("utf-8")
data_json = json.loads(chunk) data_json = json.loads(chunk)
@ -1049,11 +1003,6 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
response_obj = self.handle_huggingface_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "predibase": elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
response_obj = self.handle_predibase_chunk(chunk) response_obj = self.handle_predibase_chunk(chunk)
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]

View file

@ -11,6 +11,10 @@ from litellm.constants import (
DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_HEIGHT,
DEFAULT_IMAGE_TOKEN_COUNT, DEFAULT_IMAGE_TOKEN_COUNT,
DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_WIDTH,
MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES,
MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES,
MAX_TILE_HEIGHT,
MAX_TILE_WIDTH,
) )
from litellm.llms.custom_httpx.http_handler import _get_httpx_client from litellm.llms.custom_httpx.http_handler import _get_httpx_client
@ -97,11 +101,14 @@ def resize_image_high_res(
height: int, height: int,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
# Maximum dimensions for high res mode # Maximum dimensions for high res mode
max_short_side = 768 max_short_side = MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES
max_long_side = 2000 max_long_side = MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES
# Return early if no resizing is needed # Return early if no resizing is needed
if width <= 768 and height <= 768: if (
width <= MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES
and height <= MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES
):
return width, height return width, height
# Determine the longer and shorter sides # Determine the longer and shorter sides
@ -132,7 +139,10 @@ def resize_image_high_res(
# Test the function with the given example # Test the function with the given example
def calculate_tiles_needed( def calculate_tiles_needed(
resized_width, resized_height, tile_width=512, tile_height=512 resized_width,
resized_height,
tile_width=MAX_TILE_WIDTH,
tile_height=MAX_TILE_HEIGHT,
): ):
tiles_across = (resized_width + tile_width - 1) // tile_width tiles_across = (resized_width + tile_width - 1) // tile_width
tiles_down = (resized_height + tile_height - 1) // tile_height tiles_down = (resized_height + tile_height - 1) // tile_height

View file

@ -21,7 +21,6 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.types.llms.anthropic import ( from litellm.types.llms.anthropic import (
AnthropicChatCompletionUsageBlock,
ContentBlockDelta, ContentBlockDelta,
ContentBlockStart, ContentBlockStart,
ContentBlockStop, ContentBlockStop,
@ -32,13 +31,13 @@ from litellm.types.llms.anthropic import (
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
ChatCompletionThinkingBlock, ChatCompletionThinkingBlock,
ChatCompletionToolCallChunk, ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
) )
from litellm.types.utils import ( from litellm.types.utils import (
Delta, Delta,
GenericStreamingChunk, GenericStreamingChunk,
ModelResponseStream, ModelResponseStream,
StreamingChoices, StreamingChoices,
Usage,
) )
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
@ -487,10 +486,8 @@ class ModelResponseIterator:
return True return True
return False return False
def _handle_usage( def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
self, anthropic_usage_chunk: Union[dict, UsageDelta] usage_block = Usage(
) -> AnthropicChatCompletionUsageBlock:
usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0), prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0), completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
total_tokens=anthropic_usage_chunk.get("input_tokens", 0) total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
@ -581,7 +578,7 @@ class ModelResponseIterator:
text = "" text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None tool_use: Optional[ChatCompletionToolCallChunk] = None
finish_reason = "" finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None usage: Optional[Usage] = None
provider_specific_fields: Dict[str, Any] = {} provider_specific_fields: Dict[str, Any] = {}
reasoning_content: Optional[str] = None reasoning_content: Optional[str] = None
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None

View file

@ -5,7 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import httpx import httpx
import litellm import litellm
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME from litellm.constants import (
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS,
RESPONSE_FORMAT_TOOL_NAME,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
from litellm.llms.base_llm.base_utils import type_to_response_format_param from litellm.llms.base_llm.base_utils import type_to_response_format_param
@ -30,9 +33,16 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk, ChatCompletionToolCallFunctionChunk,
ChatCompletionToolParam, ChatCompletionToolParam,
) )
from litellm.types.utils import CompletionTokensDetailsWrapper
from litellm.types.utils import Message as LitellmMessage from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import ModelResponse, Usage, add_dummy_tool, has_tool_call_blocks from litellm.utils import (
ModelResponse,
Usage,
add_dummy_tool,
has_tool_call_blocks,
token_counter,
)
from ..common_utils import AnthropicError, process_anthropic_headers from ..common_utils import AnthropicError, process_anthropic_headers
@ -53,7 +63,7 @@ class AnthropicConfig(BaseConfig):
max_tokens: Optional[ max_tokens: Optional[
int int
] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) ] = DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
stop_sequences: Optional[list] = None stop_sequences: Optional[list] = None
temperature: Optional[int] = None temperature: Optional[int] = None
top_p: Optional[int] = None top_p: Optional[int] = None
@ -65,7 +75,7 @@ class AnthropicConfig(BaseConfig):
self, self,
max_tokens: Optional[ max_tokens: Optional[
int int
] = 4096, # You can pass in a value yourself or use the default value 4096 ] = DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS, # You can pass in a value yourself or use the default value 4096
stop_sequences: Optional[list] = None, stop_sequences: Optional[list] = None,
temperature: Optional[int] = None, temperature: Optional[int] = None,
top_p: Optional[int] = None, top_p: Optional[int] = None,
@ -309,6 +319,33 @@ class AnthropicConfig(BaseConfig):
else: else:
raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}") raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}")
def map_response_format_to_anthropic_tool(
self, value: Optional[dict], optional_params: dict, is_thinking_enabled: bool
) -> Optional[AnthropicMessagesTool]:
ignore_response_format_types = ["text"]
if (
value is None or value["type"] in ignore_response_format_types
): # value is a no-op
return None
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
"""
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
_tool = self._create_json_tool_call_for_response_format(
json_schema=json_schema,
)
return _tool
def map_openai_params( def map_openai_params(
self, self,
non_default_params: dict, non_default_params: dict,
@ -352,34 +389,18 @@ class AnthropicConfig(BaseConfig):
if param == "top_p": if param == "top_p":
optional_params["top_p"] = value optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict): if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"] _tool = self.map_response_format_to_anthropic_tool(
if value["type"] in ignore_response_format_types: # value is a no-op value, optional_params, is_thinking_enabled
)
if _tool is None:
continue continue
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
"""
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
if not is_thinking_enabled: if not is_thinking_enabled:
_tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"} _tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
optional_params["tool_choice"] = _tool_choice optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
_tool = self._create_json_tool_call_for_response_format(
json_schema=json_schema,
)
optional_params = self._add_tools_to_optional_params( optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=[_tool] optional_params=optional_params, tools=[_tool]
) )
optional_params["json_mode"] = True
if param == "user": if param == "user":
optional_params["metadata"] = {"user_id": value} optional_params["metadata"] = {"user_id": value}
if param == "thinking": if param == "thinking":
@ -769,6 +790,15 @@ class AnthropicConfig(BaseConfig):
prompt_tokens_details = PromptTokensDetailsWrapper( prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens cached_tokens=cache_read_input_tokens
) )
completion_token_details = (
CompletionTokensDetailsWrapper(
reasoning_tokens=token_counter(
text=reasoning_content, count_response_tokens=True
)
)
if reasoning_content
else None
)
total_tokens = prompt_tokens + completion_tokens total_tokens = prompt_tokens + completion_tokens
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
@ -777,6 +807,7 @@ class AnthropicConfig(BaseConfig):
prompt_tokens_details=prompt_tokens_details, prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens, cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens, cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_token_details,
) )
setattr(model_response, "usage", usage) # type: ignore setattr(model_response, "usage", usage) # type: ignore

View file

@ -11,6 +11,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
import httpx import httpx
import litellm import litellm
from litellm.constants import DEFAULT_MAX_TOKENS
from litellm.litellm_core_utils.prompt_templates.factory import ( from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt, custom_prompt,
prompt_factory, prompt_factory,
@ -65,7 +66,9 @@ class AnthropicTextConfig(BaseConfig):
def __init__( def __init__(
self, self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default max_tokens_to_sample: Optional[
int
] = DEFAULT_MAX_TOKENS, # anthropic requires a default
stop_sequences: Optional[list] = None, stop_sequences: Optional[list] = None,
temperature: Optional[int] = None, temperature: Optional[int] = None,
top_p: Optional[int] = None, top_p: Optional[int] = None,

View file

@ -7,7 +7,7 @@ import httpx # type: ignore
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
import litellm import litellm
from litellm.constants import DEFAULT_MAX_RETRIES from litellm.constants import AZURE_OPERATION_POLLING_TIMEOUT, DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
@ -857,7 +857,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
await response.aread() await response.aread()
timeout_secs: int = 120 timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
start_time = time.time() start_time = time.time()
if "status" not in response.json(): if "status" not in response.json():
raise Exception( raise Exception(
@ -955,7 +955,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
response.read() response.read()
timeout_secs: int = 120 timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
start_time = time.time() start_time = time.time()
if "status" not in response.json(): if "status" not in response.json():
raise Exception( raise Exception(

View file

@ -7,6 +7,10 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
convert_to_azure_openai_messages, convert_to_azure_openai_messages,
) )
from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.azure import (
API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT,
API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT,
)
from litellm.types.utils import ModelResponse from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema from litellm.utils import supports_response_schema
@ -123,7 +127,10 @@ class AzureOpenAIConfig(BaseConfig):
- check if api_version is supported for response_format - check if api_version is supported for response_format
""" """
is_supported = int(api_version_year) <= 2024 and int(api_version_month) >= 8 is_supported = (
int(api_version_year) <= API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT
and int(api_version_month) >= API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT
)
return is_supported return is_supported

View file

@ -14,6 +14,7 @@ Translations handled by LiteLLM:
from typing import List, Optional from typing import List, Optional
import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.openai import AllMessageValues
from litellm.utils import get_model_info from litellm.utils import get_model_info
@ -22,6 +23,27 @@ from ...openai.chat.o_series_transformation import OpenAIOSeriesConfig
class AzureOpenAIO1Config(OpenAIOSeriesConfig): class AzureOpenAIO1Config(OpenAIOSeriesConfig):
def get_supported_openai_params(self, model: str) -> list:
"""
Get the supported OpenAI params for the Azure O-Series models
"""
all_openai_params = litellm.OpenAIGPTConfig().get_supported_openai_params(
model=model
)
non_supported_params = [
"logprobs",
"top_p",
"presence_penalty",
"frequency_penalty",
"top_logprobs",
]
o_series_only_param = ["reasoning_effort"]
all_openai_params.extend(o_series_only_param)
return [
param for param in all_openai_params if param not in non_supported_params
]
def should_fake_stream( def should_fake_stream(
self, self,
model: Optional[str], model: Optional[str],

View file

@ -28,11 +28,11 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
self, self,
create_file_data: CreateFileRequest, create_file_data: CreateFileRequest,
openai_client: AsyncAzureOpenAI, openai_client: AsyncAzureOpenAI,
) -> FileObject: ) -> OpenAIFileObject:
verbose_logger.debug("create_file_data=%s", create_file_data) verbose_logger.debug("create_file_data=%s", create_file_data)
response = await openai_client.files.create(**create_file_data) response = await openai_client.files.create(**create_file_data)
verbose_logger.debug("create_file_response=%s", response) verbose_logger.debug("create_file_response=%s", response)
return response return OpenAIFileObject(**response.model_dump())
def create_file( def create_file(
self, self,
@ -66,7 +66,7 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
raise ValueError( raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
) )
return self.acreate_file( # type: ignore return self.acreate_file(
create_file_data=create_file_data, openai_client=openai_client create_file_data=create_file_data, openai_client=openai_client
) )
response = cast(AzureOpenAI, openai_client).files.create(**create_file_data) response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)

View file

@ -2,6 +2,7 @@ import json
from abc import abstractmethod from abc import abstractmethod
from typing import Optional, Union from typing import Optional, Union
import litellm
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
@ -33,6 +34,18 @@ class BaseModelResponseIterator:
self, str_line: str self, str_line: str
) -> Union[GenericStreamingChunk, ModelResponseStream]: ) -> Union[GenericStreamingChunk, ModelResponseStream]:
# chunk is a str at this point # chunk is a str at this point
stripped_chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
str_line
)
try:
if stripped_chunk is not None:
stripped_json_chunk: Optional[dict] = json.loads(stripped_chunk)
else:
stripped_json_chunk = None
except json.JSONDecodeError:
stripped_json_chunk = None
if "[DONE]" in str_line: if "[DONE]" in str_line:
return GenericStreamingChunk( return GenericStreamingChunk(
text="", text="",
@ -42,9 +55,8 @@ class BaseModelResponseIterator:
index=0, index=0,
tool_use=None, tool_use=None,
) )
elif str_line.startswith("data:"): elif stripped_json_chunk:
data_json = json.loads(str_line[5:]) return self.chunk_parser(chunk=stripped_json_chunk)
return self.chunk_parser(chunk=data_json)
else: else:
return GenericStreamingChunk( return GenericStreamingChunk(
text="", text="",
@ -85,6 +97,7 @@ class BaseModelResponseIterator:
async def __anext__(self): async def __anext__(self):
try: try:
chunk = await self.async_response_iterator.__anext__() chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration: except StopAsyncIteration:
raise StopAsyncIteration raise StopAsyncIteration
except ValueError as e: except ValueError as e:
@ -99,7 +112,9 @@ class BaseModelResponseIterator:
str_line = str_line[index:] str_line = str_line[index:]
# chunk is a str at this point # chunk is a str at this point
return self._handle_string_chunk(str_line=str_line) chunk = self._handle_string_chunk(str_line=str_line)
return chunk
except StopAsyncIteration: except StopAsyncIteration:
raise StopAsyncIteration raise StopAsyncIteration
except ValueError as e: except ValueError as e:

View file

@ -3,6 +3,7 @@ Utility functions for base LLM classes.
""" """
import copy import copy
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Type, Union from typing import List, Optional, Type, Union
@ -10,8 +11,8 @@ from openai.lib import _parsing, _pydantic
from pydantic import BaseModel from pydantic import BaseModel
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
from litellm.types.utils import ProviderSpecificModelInfo from litellm.types.utils import Message, ProviderSpecificModelInfo
class BaseLLMModelInfo(ABC): class BaseLLMModelInfo(ABC):
@ -55,6 +56,32 @@ class BaseLLMModelInfo(ABC):
pass pass
def _convert_tool_response_to_message(
tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[Message]:
"""
In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format
"""
## HANDLE JSON MODE - anthropic returns single function call
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
try:
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
if isinstance(args, dict) and (values := args.get("values")) is not None:
_message = Message(content=json.dumps(values))
return _message
else:
# a lot of the times the `values` key is not present in the tool response
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
_message = Message(content=json.dumps(args))
return _message
except json.JSONDecodeError:
# json decode error does occur, return the original tool response str
return Message(content=json_mode_content_str)
return None
def _dict_to_response_format_helper( def _dict_to_response_format_helper(
response_format: dict, ref_template: Optional[str] = None response_format: dict, ref_template: Optional[str] = None
) -> dict: ) -> dict:

View file

@ -9,7 +9,7 @@ from pydantic import BaseModel
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL, BEDROCK_MAX_POLICY_SIZE
from litellm.litellm_core_utils.dd_tracing import tracer from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret
@ -381,7 +381,7 @@ class BaseAWSLLM:
"region_name": aws_region_name, "region_name": aws_region_name,
} }
if sts_response["PackedPolicySize"] > 75: if sts_response["PackedPolicySize"] > BEDROCK_MAX_POLICY_SIZE:
verbose_logger.warning( verbose_logger.warning(
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}" f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
) )

View file

@ -368,6 +368,7 @@ class BaseLLMHTTPHandler:
else None else None
), ),
litellm_params=litellm_params, litellm_params=litellm_params,
json_mode=json_mode,
) )
return CustomStreamWrapper( return CustomStreamWrapper(
completion_stream=completion_stream, completion_stream=completion_stream,
@ -420,6 +421,7 @@ class BaseLLMHTTPHandler:
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
fake_stream: bool = False, fake_stream: bool = False,
client: Optional[HTTPHandler] = None, client: Optional[HTTPHandler] = None,
json_mode: bool = False,
) -> Tuple[Any, dict]: ) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler): if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client( sync_httpx_client = _get_httpx_client(
@ -447,11 +449,15 @@ class BaseLLMHTTPHandler:
if fake_stream is True: if fake_stream is True:
completion_stream = provider_config.get_model_response_iterator( completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.json(), sync_stream=True streaming_response=response.json(),
sync_stream=True,
json_mode=json_mode,
) )
else: else:
completion_stream = provider_config.get_model_response_iterator( completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.iter_lines(), sync_stream=True streaming_response=response.iter_lines(),
sync_stream=True,
json_mode=json_mode,
) )
# LOGGING # LOGGING

View file

@ -1,84 +0,0 @@
"""
Handles the chat completion request for Databricks
"""
from typing import Callable, List, Optional, Union, cast
from httpx._config import Timeout
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CustomStreamingDecoder
from litellm.utils import ModelResponse
from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import DatabricksBase
from .transformation import DatabricksConfig
class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def completion(
self,
*,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key: Optional[str],
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
timeout: Optional[Union[float, Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
custom_endpoint: Optional[bool] = None,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
fake_stream: bool = False,
):
messages = DatabricksConfig()._transform_messages(
messages=cast(List[AllMessageValues], messages), model=model
)
api_base, headers = self.databricks_validate_environment(
api_base=api_base,
api_key=api_key,
endpoint_type="chat_completions",
custom_endpoint=custom_endpoint,
headers=headers,
)
if optional_params.get("stream") is True:
fake_stream = DatabricksConfig()._should_fake_stream(optional_params)
else:
fake_stream = False
return super().completion(
model=model,
messages=messages,
api_base=api_base,
custom_llm_provider=custom_llm_provider,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
custom_endpoint=True,
streaming_decoder=streaming_decoder,
fake_stream=fake_stream,
)

View file

@ -2,21 +2,68 @@
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions` Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
""" """
from typing import List, Optional, Union from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
List,
Optional,
Tuple,
Union,
cast,
)
import httpx
from pydantic import BaseModel from pydantic import BaseModel
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
_handle_invalid_parallel_tool_calls,
_should_convert_tool_call_to_json_mode,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import ( from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion, handle_messages_with_content_list_to_str_conversion,
strip_name_from_messages, strip_name_from_messages,
) )
from litellm.types.llms.openai import AllMessageValues from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.utils import ProviderField from litellm.types.llms.anthropic import AnthropicMessagesTool
from litellm.types.llms.databricks import (
AllDatabricksContentValues,
DatabricksChoice,
DatabricksFunction,
DatabricksResponse,
DatabricksTool,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionThinkingBlock,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
)
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Choices,
Message,
ModelResponse,
ModelResponseStream,
ProviderField,
Usage,
)
from ...anthropic.chat.transformation import AnthropicConfig
from ...openai_like.chat.transformation import OpenAILikeChatConfig from ...openai_like.chat.transformation import OpenAILikeChatConfig
from ..common_utils import DatabricksBase, DatabricksException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class DatabricksConfig(OpenAILikeChatConfig): class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
""" """
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
""" """
@ -63,6 +110,39 @@ class DatabricksConfig(OpenAILikeChatConfig):
), ),
] ]
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
api_base, headers = self.databricks_validate_environment(
api_base=api_base,
api_key=api_key,
endpoint_type="chat_completions",
custom_endpoint=False,
headers=headers,
)
# Ensure Content-Type header is set
headers["Content-Type"] = "application/json"
return headers
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
api_base = self._get_api_base(api_base)
complete_url = f"{api_base}/chat/completions"
return complete_url
def get_supported_openai_params(self, model: Optional[str] = None) -> list: def get_supported_openai_params(self, model: Optional[str] = None) -> list:
return [ return [
"stream", "stream",
@ -75,8 +155,98 @@ class DatabricksConfig(OpenAILikeChatConfig):
"response_format", "response_format",
"tools", "tools",
"tool_choice", "tool_choice",
"reasoning_effort",
"thinking",
] ]
def convert_anthropic_tool_to_databricks_tool(
self, tool: Optional[AnthropicMessagesTool]
) -> Optional[DatabricksTool]:
if tool is None:
return None
return DatabricksTool(
type="function",
function=DatabricksFunction(
name=tool["name"],
parameters=cast(dict, tool.get("input_schema") or {}),
),
)
def map_response_format_to_databricks_tool(
self,
model: str,
value: Optional[dict],
optional_params: dict,
is_thinking_enabled: bool,
) -> Optional[DatabricksTool]:
if value is None:
return None
tool = self.map_response_format_to_anthropic_tool(
value, optional_params, is_thinking_enabled
)
databricks_tool = self.convert_anthropic_tool_to_databricks_tool(tool)
return databricks_tool
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
replace_max_completion_tokens_with_max_tokens: bool = True,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
mapped_params = super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
if (
"max_completion_tokens" in non_default_params
and replace_max_completion_tokens_with_max_tokens
):
mapped_params["max_tokens"] = non_default_params[
"max_completion_tokens"
] # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens'
mapped_params.pop("max_completion_tokens", None)
if "response_format" in non_default_params and "claude" in model:
_tool = self.map_response_format_to_databricks_tool(
model,
non_default_params["response_format"],
mapped_params,
is_thinking_enabled,
)
if _tool is not None:
self._add_tools_to_optional_params(
optional_params=optional_params, tools=[_tool]
)
optional_params["json_mode"] = True
if not is_thinking_enabled:
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(
name=RESPONSE_FORMAT_TOOL_NAME
),
)
optional_params["tool_choice"] = _tool_choice
optional_params.pop(
"response_format", None
) # unsupported for claude models - if json_schema -> convert to tool call
if "reasoning_effort" in non_default_params and "claude" in model:
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
non_default_params.get("reasoning_effort")
)
## handle thinking tokens
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=mapped_params
)
return mapped_params
def _should_fake_stream(self, optional_params: dict) -> bool: def _should_fake_stream(self, optional_params: dict) -> bool:
""" """
Databricks doesn't support 'response_format' while streaming Databricks doesn't support 'response_format' while streaming
@ -104,3 +274,259 @@ class DatabricksConfig(OpenAILikeChatConfig):
new_messages = handle_messages_with_content_list_to_str_conversion(new_messages) new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
new_messages = strip_name_from_messages(new_messages) new_messages = strip_name_from_messages(new_messages)
return super()._transform_messages(messages=new_messages, model=model) return super()._transform_messages(messages=new_messages, model=model)
@staticmethod
def extract_content_str(
content: Optional[AllDatabricksContentValues],
) -> Optional[str]:
if content is None:
return None
if isinstance(content, str):
return content
elif isinstance(content, list):
content_str = ""
for item in content:
if item["type"] == "text":
content_str += item["text"]
return content_str
else:
raise Exception(f"Unsupported content type: {type(content)}")
@staticmethod
def extract_reasoning_content(
content: Optional[AllDatabricksContentValues],
) -> Tuple[Optional[str], Optional[List[ChatCompletionThinkingBlock]]]:
"""
Extract and return the reasoning content and thinking blocks
"""
if content is None:
return None, None
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
reasoning_content: Optional[str] = None
if isinstance(content, list):
for item in content:
if item["type"] == "reasoning":
for sum in item["summary"]:
if reasoning_content is None:
reasoning_content = ""
reasoning_content += sum["text"]
thinking_block = ChatCompletionThinkingBlock(
type="thinking",
thinking=sum["text"],
signature=sum["signature"],
)
if thinking_blocks is None:
thinking_blocks = []
thinking_blocks.append(thinking_block)
return reasoning_content, thinking_blocks
def _transform_choices(
self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
) -> List[Choices]:
transformed_choices = []
for choice in choices:
## HANDLE JSON MODE - anthropic returns single function call]
tool_calls = choice["message"].get("tool_calls", None)
if tool_calls is not None:
_openai_tool_calls = []
for _tc in tool_calls:
_openai_tc = ChatCompletionMessageToolCall(**_tc) # type: ignore
_openai_tool_calls.append(_openai_tc)
fixed_tool_calls = _handle_invalid_parallel_tool_calls(
_openai_tool_calls
)
if fixed_tool_calls is not None:
tool_calls = fixed_tool_calls
translated_message: Optional[Message] = None
finish_reason: Optional[str] = None
if tool_calls and _should_convert_tool_call_to_json_mode(
tool_calls=tool_calls,
convert_tool_call_to_json_mode=json_mode,
):
# to support response_format on claude models
json_mode_content_str: Optional[str] = (
str(tool_calls[0]["function"].get("arguments", "")) or None
)
if json_mode_content_str is not None:
translated_message = Message(content=json_mode_content_str)
finish_reason = "stop"
if translated_message is None:
## get the content str
content_str = DatabricksConfig.extract_content_str(
choice["message"]["content"]
)
## get the reasoning content
(
reasoning_content,
thinking_blocks,
) = DatabricksConfig.extract_reasoning_content(
choice["message"].get("content")
)
translated_message = Message(
role="assistant",
content=content_str,
reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,
tool_calls=choice["message"].get("tool_calls"),
)
if finish_reason is None:
finish_reason = choice["finish_reason"]
translated_choice = Choices(
finish_reason=finish_reason,
index=choice["index"],
message=translated_message,
logprobs=None,
enhancements=None,
)
transformed_choices.append(translated_choice)
return transformed_choices
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = DatabricksResponse(**raw_response.json()) # type: ignore
except Exception as e:
response_headers = getattr(raw_response, "headers", None)
raise DatabricksException(
message="Unable to get json response - {}, Original Response: {}".format(
str(e), raw_response.text
),
status_code=raw_response.status_code,
headers=response_headers,
)
model_response.model = completion_response["model"]
model_response.id = completion_response["id"]
model_response.created = completion_response["created"]
setattr(model_response, "usage", Usage(**completion_response["usage"]))
model_response.choices = self._transform_choices( # type: ignore
choices=completion_response["choices"],
json_mode=json_mode,
)
return model_response
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return DatabricksChatResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class DatabricksChatResponseIterator(BaseModelResponseIterator):
def __init__(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
super().__init__(streaming_response, sync_stream)
self.json_mode = json_mode
self._last_function_name = None # Track the last seen function name
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
try:
translated_choices = []
for choice in chunk["choices"]:
tool_calls = choice["delta"].get("tool_calls")
if tool_calls and self.json_mode:
# 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME
# 2. If no function name, just args -> check last function name (saved via state variable)
# 3. Convert args to json
# 4. Convert json to message
# 5. Set content to message.content
# 6. Set tool_calls to None
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.base_llm.base_utils import (
_convert_tool_response_to_message,
)
# Check if this chunk has a function name
function_name = tool_calls[0].get("function", {}).get("name")
if function_name is not None:
self._last_function_name = function_name
# If we have a saved function name that matches RESPONSE_FORMAT_TOOL_NAME
# or this chunk has the matching function name
if (
self._last_function_name == RESPONSE_FORMAT_TOOL_NAME
or function_name == RESPONSE_FORMAT_TOOL_NAME
):
# Convert tool calls to message format
message = _convert_tool_response_to_message(tool_calls)
if message is not None:
if message.content == "{}": # empty json
message.content = ""
choice["delta"]["content"] = message.content
choice["delta"]["tool_calls"] = None
# extract the content str
content_str = DatabricksConfig.extract_content_str(
choice["delta"].get("content")
)
# extract the reasoning content
(
reasoning_content,
thinking_blocks,
) = DatabricksConfig.extract_reasoning_content(
choice["delta"]["content"]
)
choice["delta"]["content"] = content_str
choice["delta"]["reasoning_content"] = reasoning_content
choice["delta"]["thinking_blocks"] = thinking_blocks
translated_choices.append(choice)
return ModelResponseStream(
id=chunk["id"],
object="chat.completion.chunk",
created=chunk["created"],
model=chunk["model"],
choices=translated_choices,
)
except KeyError as e:
raise DatabricksException(
message=f"KeyError: {e}, Got unexpected response from Databricks: {chunk}",
status_code=400,
)
except Exception as e:
raise e

View file

@ -1,9 +1,35 @@
from typing import Literal, Optional, Tuple from typing import Literal, Optional, Tuple
from .exceptions import DatabricksError from litellm.llms.base_llm.chat.transformation import BaseLLMException
class DatabricksException(BaseLLMException):
pass
class DatabricksBase: class DatabricksBase:
def _get_api_base(self, api_base: Optional[str]) -> str:
if api_base is None:
try:
from databricks.sdk import WorkspaceClient
databricks_client = WorkspaceClient()
api_base = (
api_base or f"{databricks_client.config.host}/serving-endpoints"
)
return api_base
except ImportError:
raise DatabricksException(
status_code=400,
message=(
"Either set the DATABRICKS_API_BASE and DATABRICKS_API_KEY environment variables, "
"or install the databricks-sdk Python library."
),
)
return api_base
def _get_databricks_credentials( def _get_databricks_credentials(
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict] self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
) -> Tuple[str, dict]: ) -> Tuple[str, dict]:
@ -23,7 +49,7 @@ class DatabricksBase:
return api_base, headers return api_base, headers
except ImportError: except ImportError:
raise DatabricksError( raise DatabricksException(
status_code=400, status_code=400,
message=( message=(
"If the Databricks base URL and API key are not set, the databricks-sdk " "If the Databricks base URL and API key are not set, the databricks-sdk "
@ -41,9 +67,9 @@ class DatabricksBase:
custom_endpoint: Optional[bool], custom_endpoint: Optional[bool],
headers: Optional[dict], headers: Optional[dict],
) -> Tuple[str, dict]: ) -> Tuple[str, dict]:
if api_key is None and headers is None: if api_key is None and not headers: # handle empty headers
if custom_endpoint is not None: if custom_endpoint is not None:
raise DatabricksError( raise DatabricksException(
status_code=400, status_code=400,
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
) )
@ -54,7 +80,7 @@ class DatabricksBase:
if api_base is None: if api_base is None:
if custom_endpoint: if custom_endpoint:
raise DatabricksError( raise DatabricksException(
status_code=400, status_code=400,
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
) )

View file

@ -1,12 +0,0 @@
import httpx
class DatabricksError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs

View file

@ -1,6 +1,7 @@
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import litellm import litellm
from litellm.constants import MIN_NON_ZERO_TEMPERATURE
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.secret_managers.main import get_secret_str from litellm.secret_managers.main import get_secret_str
@ -84,7 +85,7 @@ class DeepInfraConfig(OpenAIGPTConfig):
and value == 0 and value == 0
and model == "mistralai/Mistral-7B-Instruct-v0.1" and model == "mistralai/Mistral-7B-Instruct-v0.1"
): # this model does no support temperature == 0 ): # this model does no support temperature == 0
value = 0.0001 # close to 0 value = MIN_NON_ZERO_TEMPERATURE # close to 0
if param == "tool_choice": if param == "tool_choice":
if ( if (
value != "auto" and value != "none" value != "auto" and value != "none"

View file

@ -4,6 +4,12 @@ For calculating cost of fireworks ai serverless inference models.
from typing import Tuple from typing import Tuple
from litellm.constants import (
FIREWORKS_AI_16_B,
FIREWORKS_AI_56_B_MOE,
FIREWORKS_AI_80_B,
FIREWORKS_AI_176_B_MOE,
)
from litellm.types.utils import Usage from litellm.types.utils import Usage
from litellm.utils import get_model_info from litellm.utils import get_model_info
@ -25,9 +31,9 @@ def get_base_model_for_pricing(model_name: str) -> str:
moe_match = re.search(r"(\d+)x(\d+)b", model_name) moe_match = re.search(r"(\d+)x(\d+)b", model_name)
if moe_match: if moe_match:
total_billion = int(moe_match.group(1)) * int(moe_match.group(2)) total_billion = int(moe_match.group(1)) * int(moe_match.group(2))
if total_billion <= 56: if total_billion <= FIREWORKS_AI_56_B_MOE:
return "fireworks-ai-moe-up-to-56b" return "fireworks-ai-moe-up-to-56b"
elif total_billion <= 176: elif total_billion <= FIREWORKS_AI_176_B_MOE:
return "fireworks-ai-56b-to-176b" return "fireworks-ai-56b-to-176b"
# Check for standard models in the form <number>b # Check for standard models in the form <number>b
@ -37,9 +43,9 @@ def get_base_model_for_pricing(model_name: str) -> str:
params_billion = float(params_match) params_billion = float(params_match)
# Determine the category based on the number of parameters # Determine the category based on the number of parameters
if params_billion <= 16.0: if params_billion <= FIREWORKS_AI_16_B:
return "fireworks-ai-up-to-16b" return "fireworks-ai-up-to-16b"
elif params_billion <= 80.0: elif params_billion <= FIREWORKS_AI_80_B:
return "fireworks-ai-16b-80b" return "fireworks-ai-16b-80b"
# If no matches, return the original model_name # If no matches, return the original model_name

View file

@ -81,6 +81,7 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
"stop", "stop",
"logprobs", "logprobs",
"frequency_penalty", "frequency_penalty",
"modalities",
] ]
def map_openai_params( def map_openai_params(

View file

@ -1,769 +0,0 @@
## Uses the huggingface text generation inference API
import json
import os
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
get_args,
)
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.huggingface.chat.transformation import (
HuggingfaceChatConfig as HuggingfaceConfig,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import EmbeddingResponse
from litellm.types.utils import Logprobs as TextCompletionLogprobs
from litellm.types.utils import ModelResponse
from ...base import BaseLLM
from ..common_utils import HuggingfaceError
hf_chat_config = HuggingfaceConfig()
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]
def get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = HTTPHandler(concurrent_limit=1)
model_info = http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def async_get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
model_info = await http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
try:
response = await client.post(
api_base, headers=headers, data=data, stream=True, timeout=timeout
)
except httpx.HTTPStatusError as e:
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise HuggingfaceError(
status_code=e.response.status_code,
message=str(await e.response.aread()),
headers=cast(dict, error_headers) if error_headers else None,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise HuggingfaceError(status_code=500, message=str(e))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return response.aiter_lines(), response.headers
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
def completion( # noqa: PLR0915
self,
model: str,
messages: list,
api_base: Optional[str],
model_response: ModelResponse,
print_verbose: Callable,
timeout: float,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params: dict,
custom_prompt_dict={},
acompletion: bool = False,
logger_fn=None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers: dict = {},
):
super().completion()
exception_mapping_worked = False
try:
task, model = hf_chat_config.get_hf_task_for_model(model)
litellm_params["task"] = task
headers = hf_chat_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
completion_url = hf_chat_config.get_api_base(api_base=api_base, model=model)
data = hf_chat_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=data,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages) # type: ignore
else:
### ASYNC COMPLETION
return self.acompletion(
api_base=completion_url,
data=data,
headers=headers,
model_response=model_response,
encoding=encoding,
model=model,
optional_params=optional_params,
timeout=timeout,
litellm_params=litellm_params,
logging_obj=logging_obj,
api_key=api_key,
messages=messages,
client=(
client
if client is not None
and isinstance(client, AsyncHTTPHandler)
else None
),
)
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] is True:
response = client.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
### SYNC COMPLETION
else:
response = client.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return hf_chat_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
json_mode=None,
litellm_params=litellm_params,
)
except httpx.HTTPStatusError as e:
raise HuggingfaceError(
status_code=e.response.status_code,
message=e.response.text,
headers=e.response.headers,
)
except HuggingfaceError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(
self,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
encoding: Any,
model: str,
optional_params: dict,
litellm_params: dict,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: str,
messages: List[AllMessageValues],
client: Optional[AsyncHTTPHandler] = None,
):
response: Optional[httpx.Response] = None
try:
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE
)
### ASYNC COMPLETION
http_response = await client.post(
url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
)
response = http_response
return hf_chat_config.transform_response(
model=model,
raw_response=http_response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
json_mode=None,
litellm_params=litellm_params,
)
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif isinstance(e, HuggingfaceError):
raise e
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
headers=response.headers,
)
else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
messages: List[AllMessageValues],
model: str,
timeout: float,
client: Optional[AsyncHTTPHandler] = None,
):
completion_stream, _ = await make_call(
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=False,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
return streamwrapper
def _transform_input_on_pipeline_tag(
self, input: List, pipeline_tag: Optional[str]
) -> dict:
if pipeline_tag is None:
return {"inputs": input}
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="sentence-similarity requires 2+ sentences",
)
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
elif pipeline_tag == "rerank":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="reranker requires 2+ sentences",
)
return {"inputs": {"query": input[0], "texts": input[1:]}}
return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input(
self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
special_options_keys = HuggingfaceConfig().get_special_options_params()
special_parameters_keys = [
"min_length",
"max_length",
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"max_time",
]
for k, v in optional_params.items():
if k in special_options_keys:
data.setdefault("options", {})
data["options"][k] = v
elif k in special_parameters_keys:
data.setdefault("parameters", {})
data["parameters"][k] = v
else:
data[k] = v
return data
def _transform_input(
self,
input: List,
model: str,
call_type: Literal["sync", "async"],
optional_params: dict,
embed_url: str,
) -> dict:
data: Dict = {}
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input}
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
) # type: ignore
data = self._transform_input_on_pipeline_tag(
input=input, pipeline_tag=hf_task
)
if len(optional_params.keys()) > 0:
data = self._process_optional_params(
data=data, optional_params=optional_params
)
return data
def _process_embedding_response(
self,
embeddings: dict,
model_response: EmbeddingResponse,
model: str,
input: List,
encoding: Any,
) -> EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=input_tokens,
completion_tokens=input_tokens,
total_tokens=input_tokens,
prompt_tokens_details=None,
completion_tokens_details=None,
),
)
return model_response
async def aembedding(
self,
model: str,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def embedding(
self,
model: str,
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
) -> EmbeddingResponse:
super().embedding()
headers = hf_chat_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
optional_params=optional_params,
messages=[],
)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
## ROUTING ##
if aembedding is True:
return self.aembedding(
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
headers=headers,
api_base=embed_url, # type: ignore
api_key=api_key,
client=client if isinstance(client, AsyncHTTPHandler) else None,
model=model,
optional_params=optional_params,
encoding=encoding,
)
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=embed_url,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def _transform_logprobs(
self, hf_response: Optional[List]
) -> Optional[TextCompletionLogprobs]:
"""
Transform Hugging Face logprobs to OpenAI.Completion() format
"""
if hf_response is None:
return None
# Initialize an empty list for the transformed logprobs
_logprob: TextCompletionLogprobs = TextCompletionLogprobs(
text_offset=[],
token_logprobs=[],
tokens=[],
top_logprobs=[],
)
# For each Hugging Face response, transform the logprobs
for response in hf_response:
# Extract the relevant information from the response
response_details = response["details"]
top_tokens = response_details.get("top_tokens", {})
for i, token in enumerate(response_details["prefill"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
# Add the token information to the 'token_info' list
cast(List[str], _logprob.tokens).append(token_text)
cast(List[float], _logprob.token_logprobs).append(token_logprob)
# stub this to work with llm eval harness
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
top_alt_tokens
)
# For each element in the 'tokens' list, extract the relevant information
for i, token in enumerate(response_details["tokens"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
top_alt_tokens = {}
temp_top_logprobs = []
if top_tokens != {}:
temp_top_logprobs = top_tokens[i]
# top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
for elem in temp_top_logprobs:
text = elem["text"]
logprob = elem["logprob"]
top_alt_tokens[text] = logprob
# Add the token information to the 'token_info' list
cast(List[str], _logprob.tokens).append(token_text)
cast(List[float], _logprob.token_logprobs).append(token_logprob)
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
top_alt_tokens
)
# Add the text offset of the token
# This is computed as the sum of the lengths of all previous tokens
cast(List[int], _logprob.text_offset).append(
sum(len(t["text"]) for t in response_details["tokens"][:i])
)
return _logprob

View file

@ -1,27 +1,10 @@
import json import logging
import os import os
import time from typing import TYPE_CHECKING, Any, List, Optional, Union
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx import httpx
import litellm from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks, output_parser
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@ -30,176 +13,98 @@ if TYPE_CHECKING:
else: else:
LoggingClass = Any LoggingClass = Any
from litellm.llms.base_llm.chat.transformation import BaseLLMException
tgi_models_cache = None from ...openai.chat.gpt_transformation import OpenAIGPTConfig
conv_models_cache = None from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
class HuggingfaceChatConfig(BaseConfig): logger = logging.getLogger(__name__)
BASE_URL = "https://router.huggingface.co"
class HuggingFaceChatConfig(OpenAIGPTConfig):
""" """
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
""" """
hf_task: Optional[ def validate_environment(
hf_tasks
] = None # litellm-specific param, used to know the api spec to use when calling huggingface api
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_n_tokens: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
self, self,
best_of: Optional[int] = None, headers: dict,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_special_options_params(self):
return ["use_cache", "wait_for_model"]
def get_supported_openai_params(self, model: str):
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str, model: str,
drop_params: bool, messages: List[AllMessageValues],
) -> Dict: optional_params: dict,
for param, value in non_default_params.items(): api_key: Optional[str] = None,
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None api_base: Optional[str] = None,
if param == "temperature": ) -> dict:
if value == 0.0 or value == 0: default_headers = {
# hugging face exception raised when temp==0 "content-type": "application/json",
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive }
value = 0.01 if api_key is not None:
optional_params["temperature"] = value default_headers["Authorization"] = f"Bearer {api_key}"
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params[
"do_sample"
] = True # Need to sample if you want best of for hf inference endpoints
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens" or param == "max_completion_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params headers = {**headers, **default_headers}
def get_hf_api_key(self) -> Optional[str]: return headers
return get_secret_str("HUGGINGFACE_API_KEY")
def read_tgi_conv_models(self): def get_error_class(
try: self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
global tgi_models_cache, conv_models_cache ) -> BaseLLMException:
# Check if the cache is already populated return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
# so we don't keep on reading txt file if there are 1k requests
if (tgi_models_cache is not None) and (conv_models_cache is not None):
return tgi_models_cache, conv_models_cache
# If not, read the file and populate the cache
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
script_directory = os.path.dirname(script_directory)
# Construct the file path relative to the script's directory
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, "r") as file: def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
for line in file: """
tgi_models.add(line.strip()) Get the API base for the Huggingface API.
# Cache the set for future use Do not add the chat/embedding/rerank extension here. Let the handler do this.
tgi_models_cache = tgi_models """
if model.startswith(("http://", "https://")):
base_url = model
elif base_url is None:
base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
return base_url
# If not, read the file and populate the cache def get_complete_url(
file_path = os.path.join( self,
script_directory, api_base: Optional[str],
"huggingface_llms_metadata", api_key: Optional[str],
"hf_conversational_models.txt", model: str,
) optional_params: dict,
conv_models = set() litellm_params: dict,
with open(file_path, "r") as file: stream: Optional[bool] = None,
for line in file: ) -> str:
conv_models.add(line.strip()) """
# Cache the set for future use Get the complete URL for the API call.
conv_models_cache = conv_models For provider-specific routing through huggingface
return tgi_models, conv_models """
except Exception: # 1. Check if api_base is provided
return set(), set() if api_base is not None:
complete_url = api_base
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]: elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
# read text file, cast it to set complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt" elif model.startswith(("http://", "https://")):
if model.split("/")[0] in hf_task_list: complete_url = model
split_model = model.split("/", 1) # 4. Default construction with provider
return split_model[0], split_model[1] # type: ignore
tgi_models, conversational_models = self.read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference", model
elif model in conversational_models:
return "conversational", model
elif "roneneldan/TinyStories" in model:
return "text-generation", model
else: else:
return "text-generation-inference", model # default to tgi # Parse provider and model
first_part, remaining = model.split("/", 1)
if "/" in remaining:
provider = first_part
else:
provider = "hf-inference"
if provider == "hf-inference":
route = f"{provider}/models/{model}/v1/chat/completions"
elif provider == "novita":
route = f"{provider}/chat/completions"
else:
route = f"{provider}/v1/chat/completions"
complete_url = f"{BASE_URL}/{route}"
# Ensure URL doesn't end with a slash
complete_url = complete_url.rstrip("/")
return complete_url
def transform_request( def transform_request(
self, self,
@ -209,381 +114,28 @@ class HuggingfaceChatConfig(BaseConfig):
litellm_params: dict, litellm_params: dict,
headers: dict, headers: dict,
) -> dict: ) -> dict:
task = litellm_params.get("task", None) if "max_retries" in optional_params:
## VALIDATE API FORMAT logger.warning("`max_retries` is not supported. It will be ignored.")
if task is None or not isinstance(task, str) or task not in hf_task_list: optional_params.pop("max_retries", None)
raise Exception( first_part, remaining = model.split("/", 1)
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks) if "/" in remaining:
) provider = first_part
model_id = remaining
## Load Config
config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
#### HANDLE SPECIAL PARAMS
special_params = self.get_special_options_params()
special_params_dict = {}
# Create a list of keys to pop after iteration
keys_to_pop = []
for k, v in optional_params.items():
if k in special_params:
special_params_dict[k] = v
keys_to_pop.append(k)
# Pop the keys from the dictionary after iteration
for k in keys_to_pop:
optional_params.pop(k)
if task == "conversational":
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
past_user_inputs = []
generated_responses = []
text = ""
for message in messages:
if message["role"] == "user":
if text != "":
past_user_inputs.append(text)
text = convert_content_list_to_str(message)
elif message["role"] == "assistant" or message["role"] == "system":
generated_responses.append(convert_content_list_to_str(message))
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params,
}
elif task == "text-generation-inference":
# always send "details" and "return_full_text" as params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True
if "stream" in optional_params
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] is True # type: ignore
else False
),
}
else: else:
# Non TGI and Conversational llms provider = "hf-inference"
# We need this branch, it removes 'details' and 'return_full_text' from params model_id = model
if model in litellm.custom_prompt_dict: provider_mapping = _fetch_inference_provider_mapping(model_id)
# check if the model has a registered custom prompt if provider not in provider_mapping:
model_prompt_details = litellm.custom_prompt_dict[model] raise HuggingFaceError(
prompt = custom_prompt( message=f"Model {model_id} is not supported for provider {provider}",
role_dict=model_prompt_details.get("roles", {}), status_code=404,
initial_prompt_value=model_prompt_details.get( headers={},
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params and optional_params["stream"] is True
else False
)
### RE-ADD SPECIAL PARAMS
if len(special_params_dict.keys()) > 0:
data.update({"options": special_params_dict})
return data
def get_api_base(self, api_base: Optional[str], model: str) -> str:
"""
Get the API base for the Huggingface API.
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if "https" in model:
completion_url = model
elif api_base is not None:
completion_url = api_base
elif "HF_API_BASE" in os.environ:
completion_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
return completion_url
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = {**headers, **default_headers}
return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingfaceError(
status_code=status_code, message=error_message, headers=headers
)
def _convert_streamed_response_to_complete_response(
self,
response: httpx.Response,
logging_obj: LoggingClass,
model: str,
data: dict,
api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
## LOGGING
logging_obj.post_call(
input=data,
api_key=api_key,
original_response=completion_response,
additional_args={"complete_input_dict": data},
)
return completion_response
def convert_to_model_response_object( # noqa: PLR0915
self,
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
model_response: ModelResponse,
task: Optional[hf_tasks],
optional_params: dict,
encoding: Any,
messages: List[AllMessageValues],
model: str,
):
if task is None:
task = "text-generation-inference" # default to tgi
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response.choices[0].message.content = completion_response[ # type: ignore
"generated_text"
]
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]
):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
headers=None,
)
if len(completion_response[0]["generated_text"]) > 0:
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices.extend(choices_list)
elif task == "text-classification":
model_response.choices[0].message.content = json.dumps( # type: ignore
completion_response
) )
else: provider_mapping = provider_mapping[provider]
if ( if provider_mapping["status"] == "staging":
isinstance(completion_response, list) logger.warning(
and len(completion_response[0]["generated_text"]) > 0 f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
):
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = token_counter(model=model, messages=messages)
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
completion_tokens = 0
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
model_response._hidden_params["original_response"] = completion_response
return model_response
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
task = litellm_params.get("task", None)
is_streamed = False
if (
raw_response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
completion_response = self._convert_streamed_response_to_complete_response(
response=raw_response,
logging_obj=logging_obj,
model=model,
data=request_data,
api_key=api_key,
) )
else: mapped_model = provider_mapping["providerId"]
## LOGGING messages = self._transform_messages(messages=messages, model=mapped_model)
logging_obj.post_call( return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
input=request_data,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
if isinstance(completion_response, dict):
completion_response = [completion_response]
except Exception:
raise HuggingfaceError(
message=f"Original Response received: {raw_response.text}",
status_code=raw_response.status_code,
)
if isinstance(completion_response, dict) and "error" in completion_response:
raise HuggingfaceError(
message=completion_response["error"], # type: ignore
status_code=raw_response.status_code,
)
return self.convert_to_model_response_object(
completion_response=completion_response,
model_response=model_response,
task=task if task is not None and task in hf_task_list else None,
optional_params=optional_params,
encoding=encoding,
messages=messages,
model=model,
)

View file

@ -1,18 +1,30 @@
import os
from functools import lru_cache
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import httpx import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseLLMException
HF_HUB_URL = "https://huggingface.co"
class HuggingfaceError(BaseLLMException):
class HuggingFaceError(BaseLLMException):
def __init__( def __init__(
self, self,
status_code: int, status_code,
message: str, message,
headers: Optional[Union[dict, httpx.Headers]] = None, request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[Union[httpx.Headers, dict]] = None,
): ):
super().__init__(status_code=status_code, message=message, headers=headers) super().__init__(
status_code=status_code,
message=message,
request=request,
response=response,
headers=headers,
)
hf_tasks = Literal[ hf_tasks = Literal[
@ -43,3 +55,48 @@ def output_parser(generated_text: str):
if generated_text.endswith(token): if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1] generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text return generated_text
@lru_cache(maxsize=128)
def _fetch_inference_provider_mapping(model: str) -> dict:
"""
Fetch provider mappings for a model from the Hugging Face Hub.
Args:
model: The model identifier (e.g., 'meta-llama/Llama-2-7b')
Returns:
dict: The inference provider mapping for the model
Raises:
ValueError: If no provider mapping is found
HuggingFaceError: If the API request fails
"""
headers = {"Accept": "application/json"}
if os.getenv("HUGGINGFACE_API_KEY"):
headers["Authorization"] = f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"
path = f"{HF_HUB_URL}/api/models/{model}"
params = {"expand": ["inferenceProviderMapping"]}
try:
response = httpx.get(path, headers=headers, params=params)
response.raise_for_status()
provider_mapping = response.json().get("inferenceProviderMapping")
if provider_mapping is None:
raise ValueError(f"No provider mapping found for model {model}")
return provider_mapping
except httpx.HTTPError as e:
if hasattr(e, "response"):
status_code = getattr(e.response, "status_code", 500)
headers = getattr(e.response, "headers", {})
else:
status_code = 500
headers = {}
raise HuggingFaceError(
message=f"Failed to fetch provider mapping: {str(e)}",
status_code=status_code,
headers=headers,
)

View file

@ -0,0 +1,421 @@
import json
import os
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Union,
get_args,
)
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse
from ...base import BaseLLM
from ..common_utils import HuggingFaceError
from .transformation import HuggingFaceEmbeddingConfig
config = HuggingFaceEmbeddingConfig()
HF_HUB_URL = "https://huggingface.co"
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = HTTPHandler(concurrent_limit=1)
model_info = http_client.get(url=f"{api_base}/api/models/{model}")
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
model_info = await http_client.get(url=f"{api_base}/api/models/{model}")
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
class HuggingFaceEmbedding(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
def _transform_input_on_pipeline_tag(
self, input: List, pipeline_tag: Optional[str]
) -> dict:
if pipeline_tag is None:
return {"inputs": input}
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
if len(input) < 2:
raise HuggingFaceError(
status_code=400,
message="sentence-similarity requires 2+ sentences",
)
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
elif pipeline_tag == "rerank":
if len(input) < 2:
raise HuggingFaceError(
status_code=400,
message="reranker requires 2+ sentences",
)
return {"inputs": {"query": input[0], "texts": input[1:]}}
return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input(
self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
special_options_keys = config.get_special_options_params()
special_parameters_keys = [
"min_length",
"max_length",
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"max_time",
]
for k, v in optional_params.items():
if k in special_options_keys:
data.setdefault("options", {})
data["options"][k] = v
elif k in special_parameters_keys:
data.setdefault("parameters", {})
data["parameters"][k] = v
else:
data[k] = v
return data
def _transform_input(
self,
input: List,
model: str,
call_type: Literal["sync", "async"],
optional_params: dict,
embed_url: str,
) -> dict:
data: Dict = {}
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingFaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input}
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
) # type: ignore
data = self._transform_input_on_pipeline_tag(
input=input, pipeline_tag=hf_task
)
if len(optional_params.keys()) > 0:
data = self._process_optional_params(
data=data, optional_params=optional_params
)
return data
def _process_embedding_response(
self,
embeddings: dict,
model_response: EmbeddingResponse,
model: str,
input: List,
encoding: Any,
) -> EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=input_tokens,
completion_tokens=input_tokens,
total_tokens=input_tokens,
prompt_tokens_details=None,
completion_tokens_details=None,
),
)
return model_response
async def aembedding(
self,
model: str,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingFaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def embedding(
self,
model: str,
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
) -> EmbeddingResponse:
super().embedding()
headers = config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
optional_params=optional_params,
messages=[],
)
task_type = optional_params.pop("input_type", None)
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
## ROUTING ##
if aembedding is True:
return self.aembedding(
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
headers=headers,
api_base=embed_url, # type: ignore
api_key=api_key,
client=client if isinstance(client, AsyncHTTPHandler) else None,
model=model,
optional_params=optional_params,
encoding=encoding,
)
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=embed_url,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingFaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)

View file

@ -0,0 +1,589 @@
import json
import os
import time
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import HuggingFaceError, hf_task_list, hf_tasks, output_parser
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
tgi_models_cache = None
conv_models_cache = None
class HuggingFaceEmbeddingConfig(BaseConfig):
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[
hf_tasks
] = None # litellm-specific param, used to know the api spec to use when calling huggingface api
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_n_tokens: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_special_options_params(self):
return ["use_cache", "wait_for_model"]
def get_supported_openai_params(self, model: str):
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params[
"do_sample"
] = True # Need to sample if you want best of for hf inference endpoints
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens" or param == "max_completion_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
def get_hf_api_key(self) -> Optional[str]:
return get_secret_str("HUGGINGFACE_API_KEY")
def read_tgi_conv_models(self):
try:
global tgi_models_cache, conv_models_cache
# Check if the cache is already populated
# so we don't keep on reading txt file if there are 1k requests
if (tgi_models_cache is not None) and (conv_models_cache is not None):
return tgi_models_cache, conv_models_cache
# If not, read the file and populate the cache
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
script_directory = os.path.dirname(script_directory)
# Construct the file path relative to the script's directory
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
# Cache the set for future use
tgi_models_cache = tgi_models
# If not, read the file and populate the cache
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
conv_models_cache = conv_models
return tgi_models, conv_models
except Exception:
return set(), set()
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
split_model = model.split("/", 1)
return split_model[0], split_model[1] # type: ignore
tgi_models, conversational_models = self.read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference", model
elif model in conversational_models:
return "conversational", model
elif "roneneldan/TinyStories" in model:
return "text-generation", model
else:
return "text-generation-inference", model # default to tgi
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
task = litellm_params.get("task", None)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
## Load Config
config = litellm.HuggingFaceEmbeddingConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
#### HANDLE SPECIAL PARAMS
special_params = self.get_special_options_params()
special_params_dict = {}
# Create a list of keys to pop after iteration
keys_to_pop = []
for k, v in optional_params.items():
if k in special_params:
special_params_dict[k] = v
keys_to_pop.append(k)
# Pop the keys from the dictionary after iteration
for k in keys_to_pop:
optional_params.pop(k)
if task == "conversational":
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
past_user_inputs = []
generated_responses = []
text = ""
for message in messages:
if message["role"] == "user":
if text != "":
past_user_inputs.append(text)
text = convert_content_list_to_str(message)
elif message["role"] == "assistant" or message["role"] == "system":
generated_responses.append(convert_content_list_to_str(message))
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params,
}
elif task == "text-generation-inference":
# always send "details" and "return_full_text" as params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True
if "stream" in optional_params
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] is True # type: ignore
else False
),
}
else:
# Non TGI and Conversational llms
# We need this branch, it removes 'details' and 'return_full_text' from params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params and optional_params["stream"] is True
else False
)
### RE-ADD SPECIAL PARAMS
if len(special_params_dict.keys()) > 0:
data.update({"options": special_params_dict})
return data
def get_api_base(self, api_base: Optional[str], model: str) -> str:
"""
Get the API base for the Huggingface API.
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if "https" in model:
completion_url = model
elif api_base is not None:
completion_url = api_base
elif "HF_API_BASE" in os.environ:
completion_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
return completion_url
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = {**headers, **default_headers}
return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(
status_code=status_code, message=error_message, headers=headers
)
def _convert_streamed_response_to_complete_response(
self,
response: httpx.Response,
logging_obj: LoggingClass,
model: str,
data: dict,
api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
## LOGGING
logging_obj.post_call(
input=data,
api_key=api_key,
original_response=completion_response,
additional_args={"complete_input_dict": data},
)
return completion_response
def convert_to_model_response_object( # noqa: PLR0915
self,
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
model_response: ModelResponse,
task: Optional[hf_tasks],
optional_params: dict,
encoding: Any,
messages: List[AllMessageValues],
model: str,
):
if task is None:
task = "text-generation-inference" # default to tgi
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response.choices[0].message.content = completion_response[ # type: ignore
"generated_text"
]
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]
):
raise HuggingFaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
headers=None,
)
if len(completion_response[0]["generated_text"]) > 0:
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices.extend(choices_list)
elif task == "text-classification":
model_response.choices[0].message.content = json.dumps( # type: ignore
completion_response
)
else:
if (
isinstance(completion_response, list)
and len(completion_response[0]["generated_text"]) > 0
):
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = token_counter(model=model, messages=messages)
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
completion_tokens = 0
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
model_response._hidden_params["original_response"] = completion_response
return model_response
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
task = litellm_params.get("task", None)
is_streamed = False
if (
raw_response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
completion_response = self._convert_streamed_response_to_complete_response(
response=raw_response,
logging_obj=logging_obj,
model=model,
data=request_data,
api_key=api_key,
)
else:
## LOGGING
logging_obj.post_call(
input=request_data,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
if isinstance(completion_response, dict):
completion_response = [completion_response]
except Exception:
raise HuggingFaceError(
message=f"Original Response received: {raw_response.text}",
status_code=raw_response.status_code,
)
if isinstance(completion_response, dict) and "error" in completion_response:
raise HuggingFaceError(
message=completion_response["error"], # type: ignore
status_code=raw_response.status_code,
)
return self.convert_to_model_response_object(
completion_response=completion_response,
model_response=model_response,
task=task if task is not None and task in hf_task_list else None,
optional_params=optional_params,
encoding=encoding,
messages=messages,
model=model,
)

View file

@ -34,7 +34,7 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
return api_base, dynamic_api_key return api_base, dynamic_api_key
@staticmethod @staticmethod
def _convert_tool_response_to_message( def _json_mode_convert_tool_response_to_message(
message: ChatCompletionAssistantMessage, json_mode: bool message: ChatCompletionAssistantMessage, json_mode: bool
) -> ChatCompletionAssistantMessage: ) -> ChatCompletionAssistantMessage:
""" """
@ -88,8 +88,10 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
if json_mode: if json_mode:
for choice in response_json["choices"]: for choice in response_json["choices"]:
message = OpenAILikeChatConfig._convert_tool_response_to_message( message = (
choice.get("message"), json_mode OpenAILikeChatConfig._json_mode_convert_tool_response_to_message(
choice.get("message"), json_mode
)
) )
choice["message"] = message choice["message"] = message

View file

@ -6,12 +6,13 @@ Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
Docs: https://openrouter.ai/docs/parameters Docs: https://openrouter.ai/docs/parameters
""" """
from typing import Any, AsyncIterator, Iterator, Optional, Union from typing import Any, AsyncIterator, Iterator, List, Optional, Union
import httpx import httpx
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.openrouter import OpenRouterErrorMessage from litellm.types.llms.openrouter import OpenRouterErrorMessage
from litellm.types.utils import ModelResponse, ModelResponseStream from litellm.types.utils import ModelResponse, ModelResponseStream
@ -47,6 +48,27 @@ class OpenrouterConfig(OpenAIGPTConfig):
] = extra_body # openai client supports `extra_body` param ] = extra_body # openai client supports `extra_body` param
return mapped_openai_params return mapped_openai_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform the overall request to be sent to the API.
Returns:
dict: The transformed request. Sent as the body of the API call.
"""
extra_body = optional_params.pop("extra_body", {})
response = super().transform_request(
model, messages, optional_params, litellm_params, headers
)
response.update(extra_body)
return response
def get_error_class( def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException: ) -> BaseLLMException:

View file

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
from httpx import Headers, Response from httpx import Headers, Response
from litellm.constants import DEFAULT_MAX_TOKENS
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse from litellm.types.utils import ModelResponse
@ -27,7 +28,7 @@ class PredibaseConfig(BaseConfig):
decoder_input_details: Optional[bool] = None decoder_input_details: Optional[bool] = None
details: bool = True # enables returning logprobs + best of details: bool = True # enables returning logprobs + best of
max_new_tokens: int = ( max_new_tokens: int = (
256 # openai default - requests hang if max_new_tokens not given DEFAULT_MAX_TOKENS # openai default - requests hang if max_new_tokens not given
) )
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
return_full_text: Optional[ return_full_text: Optional[

Some files were not shown because too many files have changed in this diff Show more