mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Merge branch 'main' into fix_budget_limits
This commit is contained in:
commit
3f8b827f79
290 changed files with 10057 additions and 4647 deletions
|
@ -10,6 +10,6 @@ anthropic
|
||||||
orjson==3.9.15
|
orjson==3.9.15
|
||||||
pydantic==2.10.2
|
pydantic==2.10.2
|
||||||
google-cloud-aiplatform==1.43.0
|
google-cloud-aiplatform==1.43.0
|
||||||
fastapi-sso==0.10.0
|
fastapi-sso==0.16.0
|
||||||
uvloop==0.21.0
|
uvloop==0.21.0
|
||||||
mcp==1.5.0 # for MCP server
|
mcp==1.5.0 # for MCP server
|
||||||
|
|
318
cookbook/LiteLLM_HuggingFace.ipynb
vendored
318
cookbook/LiteLLM_HuggingFace.ipynb
vendored
|
@ -6,8 +6,9 @@
|
||||||
"id": "9dKM5k8qsMIj"
|
"id": "9dKM5k8qsMIj"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"## LiteLLM HuggingFace\n",
|
"## LiteLLM Hugging Face\n",
|
||||||
"Docs for huggingface: https://docs.litellm.ai/docs/providers/huggingface"
|
"\n",
|
||||||
|
"Docs for huggingface: https://docs.litellm.ai/docs/providers/huggingface\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -27,23 +28,18 @@
|
||||||
"id": "yp5UXRqtpu9f"
|
"id": "yp5UXRqtpu9f"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"## Hugging Face Free Serverless Inference API\n",
|
"## Serverless Inference Providers\n",
|
||||||
"Read more about the Free Serverless Inference API here: https://huggingface.co/docs/api-inference.\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"In order to use litellm to call Serverless Inference API:\n",
|
"Read more about Inference Providers here: https://huggingface.co/blog/inference-providers.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* Browse Serverless Inference compatible models here: https://huggingface.co/models?inference=warm&pipeline_tag=text-generation.\n",
|
"In order to use litellm with Hugging Face Inference Providers, you need to set `model=huggingface/<provider>/<model-id>`.\n",
|
||||||
"* Copy the model name from hugging face\n",
|
|
||||||
"* Set `model = \"huggingface/<model-name>\"`\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"Example set `model=huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct` to call `meta-llama/Meta-Llama-3.1-8B-Instruct`\n",
|
"Example: `huggingface/together/deepseek-ai/DeepSeek-R1` to run DeepSeek-R1 (https://huggingface.co/deepseek-ai/DeepSeek-R1) through Together AI.\n"
|
||||||
"\n",
|
|
||||||
"https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
|
@ -51,107 +47,18 @@
|
||||||
"id": "Pi5Oww8gpCUm",
|
"id": "Pi5Oww8gpCUm",
|
||||||
"outputId": "659a67c7-f90d-4c06-b94e-2c4aa92d897a"
|
"outputId": "659a67c7-f90d-4c06-b94e-2c4aa92d897a"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"ModelResponse(id='chatcmpl-c54dfb68-1491-4d68-a4dc-35e603ea718a', choices=[Choices(finish_reason='eos_token', index=0, message=Message(content=\"I'm just a computer program, so I don't have feelings, but thank you for asking! How can I assist you today?\", role='assistant', tool_calls=None, function_call=None))], created=1724858285, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=27, prompt_tokens=47, total_tokens=74))\n",
|
|
||||||
"ModelResponse(id='chatcmpl-d2ae38e6-4974-431c-bb9b-3fa3f95e5a6d', choices=[Choices(finish_reason='length', index=0, message=Message(content=\"\\n\\nI’m doing well, thank you. I’ve been keeping busy with work and some personal projects. How about you?\\n\\nI'm doing well, thank you. I've been enjoying some time off and catching up on some reading. How can I assist you today?\\n\\nI'm looking for a good book to read. Do you have any recommendations?\\n\\nOf course! Here are a few book recommendations across different genres:\\n\\n1.\", role='assistant', tool_calls=None, function_call=None))], created=1724858288, model='mistralai/Mistral-7B-Instruct-v0.3', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=85, prompt_tokens=6, total_tokens=91))\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import litellm\n",
|
"from litellm import completion\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Make sure to create an API_KEY with inference permissions at https://huggingface.co/settings/tokens/new?globalPermissions=inference.serverless.write&tokenType=fineGrained\n",
|
"# You can create a HF token here: https://huggingface.co/settings/tokens\n",
|
||||||
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
|
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Call https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct\n",
|
"# Call DeepSeek-R1 model through Together AI\n",
|
||||||
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n",
|
"response = completion(\n",
|
||||||
"response = litellm.completion(\n",
|
" model=\"huggingface/together/deepseek-ai/DeepSeek-R1\",\n",
|
||||||
" model=\"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
|
" messages=[{\"content\": \"How many r's are in the word `strawberry`?\", \"role\": \"user\"}],\n",
|
||||||
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}]\n",
|
|
||||||
")\n",
|
|
||||||
"print(response)\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# Call https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3\n",
|
|
||||||
"response = litellm.completion(\n",
|
|
||||||
" model=\"huggingface/mistralai/Mistral-7B-Instruct-v0.3\",\n",
|
|
||||||
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}]\n",
|
|
||||||
")\n",
|
|
||||||
"print(response)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "-klhAhjLtclv"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Hugging Face Dedicated Inference Endpoints\n",
|
|
||||||
"\n",
|
|
||||||
"Steps to use\n",
|
|
||||||
"* Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/\n",
|
|
||||||
"* Set `api_base` to your deployed api base\n",
|
|
||||||
"* Add the `huggingface/` prefix to your model so litellm knows it's a huggingface Deployed Inference Endpoint"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "Lbmw8Gl_pHns",
|
|
||||||
"outputId": "ea8408bf-1cc3-4670-ecea-f12666d204a8"
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"{\n",
|
|
||||||
" \"object\": \"chat.completion\",\n",
|
|
||||||
" \"choices\": [\n",
|
|
||||||
" {\n",
|
|
||||||
" \"finish_reason\": \"length\",\n",
|
|
||||||
" \"index\": 0,\n",
|
|
||||||
" \"message\": {\n",
|
|
||||||
" \"content\": \"\\n\\nI am doing well, thank you for asking. How about you?\\nI am doing\",\n",
|
|
||||||
" \"role\": \"assistant\",\n",
|
|
||||||
" \"logprobs\": -8.9481967812\n",
|
|
||||||
" }\n",
|
|
||||||
" }\n",
|
|
||||||
" ],\n",
|
|
||||||
" \"id\": \"chatcmpl-74dc9d89-3916-47ce-9bea-b80e66660f77\",\n",
|
|
||||||
" \"created\": 1695871068.8413374,\n",
|
|
||||||
" \"model\": \"glaiveai/glaive-coder-7b\",\n",
|
|
||||||
" \"usage\": {\n",
|
|
||||||
" \"prompt_tokens\": 6,\n",
|
|
||||||
" \"completion_tokens\": 18,\n",
|
|
||||||
" \"total_tokens\": 24\n",
|
|
||||||
" }\n",
|
|
||||||
"}\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"import litellm\n",
|
|
||||||
"\n",
|
|
||||||
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
|
|
||||||
"\n",
|
|
||||||
"# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b\n",
|
|
||||||
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n",
|
|
||||||
"# set api base to your deployed api endpoint from hugging face\n",
|
|
||||||
"response = litellm.completion(\n",
|
|
||||||
" model=\"huggingface/glaiveai/glaive-coder-7b\",\n",
|
|
||||||
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n",
|
|
||||||
" api_base=\"https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud\"\n",
|
|
||||||
")\n",
|
")\n",
|
||||||
"print(response)"
|
"print(response)"
|
||||||
]
|
]
|
||||||
|
@ -162,13 +69,12 @@
|
||||||
"id": "EU0UubrKzTFe"
|
"id": "EU0UubrKzTFe"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"## HuggingFace - Streaming (Serveless or Dedicated)\n",
|
"## Streaming\n"
|
||||||
"Set stream = True"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
|
@ -176,74 +82,147 @@
|
||||||
"id": "y-QfIvA-uJKX",
|
"id": "y-QfIvA-uJKX",
|
||||||
"outputId": "b007bb98-00d0-44a4-8264-c8a2caed6768"
|
"outputId": "b007bb98-00d0-44a4-8264-c8a2caed6768"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"<litellm.utils.CustomStreamWrapper object at 0x1278471d0>\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='I', role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"'m\", role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' just', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' a', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' computer', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' program', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=',', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' so', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' I', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' don', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"'t\", role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' have', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' feelings', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=',', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' but', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' thank', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' you', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' for', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' asking', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='!', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' How', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' can', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' I', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' assist', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' you', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=' today', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='?', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='<|eot_id|>', role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n",
|
|
||||||
"ModelResponse(id='chatcmpl-ffeb4491-624b-4ddf-8005-60358cf67d36', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1724858353, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion.chunk', system_fingerprint=None)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import litellm\n",
|
"from litellm import completion\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Make sure to create an API_KEY with inference permissions at https://huggingface.co/settings/tokens/new?globalPermissions=inference.serverless.write&tokenType=fineGrained\n",
|
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
|
||||||
"os.environ[\"HUGGINGFACE_API_KEY\"] = \"\"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Call https://huggingface.co/glaiveai/glaive-coder-7b\n",
|
"response = completion(\n",
|
||||||
"# add the 'huggingface/' prefix to the model to set huggingface as the provider\n",
|
" model=\"huggingface/together/deepseek-ai/DeepSeek-R1\",\n",
|
||||||
"# set api base to your deployed api endpoint from hugging face\n",
|
" messages=[\n",
|
||||||
"response = litellm.completion(\n",
|
" {\n",
|
||||||
" model=\"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
|
" \"role\": \"user\",\n",
|
||||||
" messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n",
|
" \"content\": \"How many r's are in the word `strawberry`?\",\n",
|
||||||
" stream=True\n",
|
" \n",
|
||||||
|
" }\n",
|
||||||
|
" ],\n",
|
||||||
|
" stream=True,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(response)\n",
|
|
||||||
"\n",
|
|
||||||
"for chunk in response:\n",
|
"for chunk in response:\n",
|
||||||
" print(chunk)"
|
" print(chunk)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## With images as input\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"id": "CKXAnK55zQRl"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": [
|
||||||
|
"from litellm import completion\n",
|
||||||
|
"\n",
|
||||||
|
"# Set your Hugging Face Token\n",
|
||||||
|
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
|
||||||
|
"\n",
|
||||||
|
"messages = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": [\n",
|
||||||
|
" {\"type\": \"text\", \"text\": \"What's in this image?\"},\n",
|
||||||
|
" {\n",
|
||||||
|
" \"type\": \"image_url\",\n",
|
||||||
|
" \"image_url\": {\n",
|
||||||
|
" \"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\",\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
|
" }\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"response = completion(\n",
|
||||||
|
" model=\"huggingface/sambanova/meta-llama/Llama-3.3-70B-Instruct\",\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
")\n",
|
||||||
|
"print(response.choices[0])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Tools - Function Calling\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from litellm import completion\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Set your Hugging Face Token\n",
|
||||||
|
"os.environ[\"HF_TOKEN\"] = \"hf_xxxxxx\"\n",
|
||||||
|
"\n",
|
||||||
|
"tools = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"type\": \"function\",\n",
|
||||||
|
" \"function\": {\n",
|
||||||
|
" \"name\": \"get_current_weather\",\n",
|
||||||
|
" \"description\": \"Get the current weather in a given location\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"location\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"location\"],\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" }\n",
|
||||||
|
"]\n",
|
||||||
|
"messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n",
|
||||||
|
"\n",
|
||||||
|
"response = completion(\n",
|
||||||
|
" model=\"huggingface/sambanova/meta-llama/Llama-3.1-8B-Instruct\", messages=messages, tools=tools, tool_choice=\"auto\"\n",
|
||||||
|
")\n",
|
||||||
|
"print(response)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Hugging Face Dedicated Inference Endpoints\n",
|
||||||
|
"\n",
|
||||||
|
"Steps to use\n",
|
||||||
|
"\n",
|
||||||
|
"- Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/\n",
|
||||||
|
"- Set `api_base` to your deployed api base\n",
|
||||||
|
"- set the model to `huggingface/tgi` so that litellm knows it's a huggingface Deployed Inference Endpoint.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import litellm\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"response = litellm.completion(\n",
|
||||||
|
" model=\"huggingface/tgi\",\n",
|
||||||
|
" messages=[{\"content\": \"Hello, how are you?\", \"role\": \"user\"}],\n",
|
||||||
|
" api_base=\"https://my-endpoint.endpoints.huggingface.cloud/v1/\",\n",
|
||||||
|
")\n",
|
||||||
|
"print(response)"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -251,7 +230,8 @@
|
||||||
"provenance": []
|
"provenance": []
|
||||||
},
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
|
@ -264,7 +244,7 @@
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.2"
|
"version": "3.12.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|
|
@ -27,16 +27,18 @@ os.environ["AWS_REGION_NAME"] = ""
|
||||||
|
|
||||||
|
|
||||||
# pdf url
|
# pdf url
|
||||||
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
|
file_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
|
||||||
|
|
||||||
# model
|
# model
|
||||||
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||||
|
|
||||||
image_content = [
|
file_content = [
|
||||||
{"type": "text", "text": "What's this file about?"},
|
{"type": "text", "text": "What's this file about?"},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "file",
|
||||||
"image_url": image_url, # OR {"url": image_url}
|
"file": {
|
||||||
|
"file_id": file_url,
|
||||||
|
}
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -46,7 +48,7 @@ if not supports_pdf_input(model, None):
|
||||||
|
|
||||||
response = completion(
|
response = completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": image_content}],
|
messages=[{"role": "user", "content": file_content}],
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
```
|
```
|
||||||
|
@ -80,11 +82,15 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
-d '{
|
-d '{
|
||||||
"model": "bedrock-model",
|
"model": "bedrock-model",
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": {"type": "text", "text": "What's this file about?"}},
|
{"role": "user", "content": [
|
||||||
{
|
{"type": "text", "text": "What's this file about?"},
|
||||||
"type": "image_url",
|
{
|
||||||
"image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
|
"type": "file",
|
||||||
}
|
"file": {
|
||||||
|
"file_id": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]},
|
||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
@ -116,11 +122,13 @@ base64_url = f"data:application/pdf;base64,{encoded_file}"
|
||||||
# model
|
# model
|
||||||
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||||
|
|
||||||
image_content = [
|
file_content = [
|
||||||
{"type": "text", "text": "What's this file about?"},
|
{"type": "text", "text": "What's this file about?"},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "file",
|
||||||
"image_url": base64_url, # OR {"url": base64_url}
|
"file": {
|
||||||
|
"file_data": base64_url,
|
||||||
|
}
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -130,11 +138,53 @@ if not supports_pdf_input(model, None):
|
||||||
|
|
||||||
response = completion(
|
response = completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": image_content}],
|
messages=[{"role": "user", "content": file_content}],
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: bedrock-model
|
||||||
|
litellm_params:
|
||||||
|
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
|
||||||
|
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
|
||||||
|
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
|
||||||
|
aws_region_name: os.environ/AWS_REGION_NAME
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-d '{
|
||||||
|
"model": "bedrock-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{"type": "text", "text": "What's this file about?"},
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"file": {
|
||||||
|
"file_data": "data:application/pdf;base64...",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]},
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
## Checking if a model supports pdf input
|
## Checking if a model supports pdf input
|
||||||
|
@ -200,92 +250,3 @@ Expected Response
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
## OpenAI 'file' message type
|
|
||||||
|
|
||||||
This is currently only supported for OpenAI models.
|
|
||||||
|
|
||||||
This will be supported for all providers soon.
|
|
||||||
|
|
||||||
<Tabs>
|
|
||||||
<TabItem value="sdk" label="SDK">
|
|
||||||
|
|
||||||
```python
|
|
||||||
import base64
|
|
||||||
from litellm import completion
|
|
||||||
|
|
||||||
with open("draconomicon.pdf", "rb") as f:
|
|
||||||
data = f.read()
|
|
||||||
|
|
||||||
base64_string = base64.b64encode(data).decode("utf-8")
|
|
||||||
|
|
||||||
completion = completion(
|
|
||||||
model="gpt-4o",
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "file",
|
|
||||||
"file": {
|
|
||||||
"filename": "draconomicon.pdf",
|
|
||||||
"file_data": f"data:application/pdf;base64,{base64_string}",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What is the first dragon in the book?",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
print(completion.choices[0].message.content)
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
|
|
||||||
<TabItem value="proxy" label="PROXY">
|
|
||||||
|
|
||||||
1. Setup config.yaml
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: openai-model
|
|
||||||
litellm_params:
|
|
||||||
model: gpt-4o
|
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Start the proxy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
litellm --config config.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Test it!
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
|
||||||
-H 'Content-Type: application/json' \
|
|
||||||
-H 'Authorization: Bearer sk-1234' \
|
|
||||||
-d '{
|
|
||||||
"model": "openai-model",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": [
|
|
||||||
{
|
|
||||||
"type": "file",
|
|
||||||
"file": {
|
|
||||||
"filename": "draconomicon.pdf",
|
|
||||||
"file_data": f"data:application/pdf;base64,{base64_string}",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]}
|
|
||||||
]
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
</Tabs>
|
|
|
@ -108,3 +108,75 @@ response = litellm.completion(
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
**additional_drop_params**: List or null - Is a list of openai params you want to drop when making a call to the model.
|
**additional_drop_params**: List or null - Is a list of openai params you want to drop when making a call to the model.
|
||||||
|
|
||||||
|
## Specify allowed openai params in a request
|
||||||
|
|
||||||
|
Tell litellm to allow specific openai params in a request. Use this if you get a `litellm.UnsupportedParamsError` and want to allow a param. LiteLLM will pass the param as is to the model.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="LiteLLM Python SDK">
|
||||||
|
|
||||||
|
In this example we pass `allowed_openai_params=["tools"]` to allow the `tools` param.
|
||||||
|
|
||||||
|
```python showLineNumbers title="Pass allowed_openai_params to LiteLLM Python SDK"
|
||||||
|
await litellm.acompletion(
|
||||||
|
model="azure/o_series/<my-deployment-name>",
|
||||||
|
api_key="xxxxx",
|
||||||
|
api_base=api_base,
|
||||||
|
messages=[{"role": "user", "content": "Hello! return a json object"}],
|
||||||
|
tools=[{"type": "function", "function": {"name": "get_current_time", "description": "Get the current time in a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name, e.g. San Francisco"}}, "required": ["location"]}}}]
|
||||||
|
allowed_openai_params=["tools"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="LiteLLM Proxy">
|
||||||
|
|
||||||
|
When using litellm proxy you can pass `allowed_openai_params` in two ways:
|
||||||
|
|
||||||
|
1. Dynamically pass `allowed_openai_params` in a request
|
||||||
|
2. Set `allowed_openai_params` on the config.yaml file for a specific model
|
||||||
|
|
||||||
|
#### Dynamically pass allowed_openai_params in a request
|
||||||
|
In this example we pass `allowed_openai_params=["tools"]` to allow the `tools` param for a request sent to the model set on the proxy.
|
||||||
|
|
||||||
|
```python showLineNumbers title="Dynamically pass allowed_openai_params in a request"
|
||||||
|
import openai
|
||||||
|
from openai import AsyncAzureOpenAI
|
||||||
|
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="anything",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={
|
||||||
|
"allowed_openai_params": ["tools"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Set allowed_openai_params on config.yaml
|
||||||
|
|
||||||
|
You can also set `allowed_openai_params` on the config.yaml file for a specific model. This means that all requests to this deployment are allowed to pass in the `tools` param.
|
||||||
|
|
||||||
|
```yaml showLineNumbers title="Set allowed_openai_params on config.yaml"
|
||||||
|
model_list:
|
||||||
|
- model_name: azure-o1-preview
|
||||||
|
litellm_params:
|
||||||
|
model: azure/o_series/<my-deployment-name>
|
||||||
|
api_key: xxxxx
|
||||||
|
api_base: https://openai-prod-test.openai.azure.com/openai/deployments/o1/chat/completions?api-version=2025-01-01-preview
|
||||||
|
allowed_openai_params: ["tools"]
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
|
@ -1076,32 +1076,24 @@ print(response)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Parallel Function calling
|
### Tool Calling / Function Calling
|
||||||
|
|
||||||
See a detailed walthrough of parallel function calling with litellm [here](https://docs.litellm.ai/docs/completion/function_call)
|
See a detailed walthrough of parallel function calling with litellm [here](https://docs.litellm.ai/docs/completion/function_call)
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# set Azure env variables
|
# set Azure env variables
|
||||||
import os
|
import os
|
||||||
|
import litellm
|
||||||
|
import json
|
||||||
|
|
||||||
os.environ['AZURE_API_KEY'] = "" # litellm reads AZURE_API_KEY from .env and sends the request
|
os.environ['AZURE_API_KEY'] = "" # litellm reads AZURE_API_KEY from .env and sends the request
|
||||||
os.environ['AZURE_API_BASE'] = "https://openai-gpt-4-test-v-1.openai.azure.com/"
|
os.environ['AZURE_API_BASE'] = "https://openai-gpt-4-test-v-1.openai.azure.com/"
|
||||||
os.environ['AZURE_API_VERSION'] = "2023-07-01-preview"
|
os.environ['AZURE_API_VERSION'] = "2023-07-01-preview"
|
||||||
|
|
||||||
import litellm
|
|
||||||
import json
|
|
||||||
# Example dummy function hard coded to return the same weather
|
|
||||||
# In production, this could be your backend API or an external API
|
|
||||||
def get_current_weather(location, unit="fahrenheit"):
|
|
||||||
"""Get the current weather in a given location"""
|
|
||||||
if "tokyo" in location.lower():
|
|
||||||
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
|
|
||||||
elif "san francisco" in location.lower():
|
|
||||||
return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"})
|
|
||||||
elif "paris" in location.lower():
|
|
||||||
return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
|
|
||||||
else:
|
|
||||||
return json.dumps({"location": location, "temperature": "unknown"})
|
|
||||||
|
|
||||||
## Step 1: send the conversation and available functions to the model
|
|
||||||
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
|
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
|
@ -1125,7 +1117,7 @@ tools = [
|
||||||
|
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model="azure/chatgpt-functioncalling", # model = azure/<your-azure-deployment-name>
|
model="azure/chatgpt-functioncalling", # model = azure/<your-azure-deployment-name>
|
||||||
messages=messages,
|
messages=[{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}],
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto", # auto is default, but we'll be explicit
|
tool_choice="auto", # auto is default, but we'll be explicit
|
||||||
)
|
)
|
||||||
|
@ -1134,8 +1126,49 @@ response_message = response.choices[0].message
|
||||||
tool_calls = response.choices[0].message.tool_calls
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
print("\nTool Choice:\n", tool_calls)
|
print("\nTool Choice:\n", tool_calls)
|
||||||
```
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: azure-gpt-3.5
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-functioncalling
|
||||||
|
api_base: os.environ/AZURE_API_BASE
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_version: "2023-07-01-preview"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://localhost:4000/v1/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-d '{
|
||||||
|
"model": "azure-gpt-3.5",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey, how'\''s it going? Thinking long and hard before replying - what is the meaning of the world and life itself"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
### Spend Tracking for Azure OpenAI Models (PROXY)
|
### Spend Tracking for Azure OpenAI Models (PROXY)
|
||||||
|
|
||||||
Set base model for cost tracking azure image-gen call
|
Set base model for cost tracking azure image-gen call
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import Tabs from '@theme/Tabs';
|
import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# 🆕 Databricks
|
# Databricks
|
||||||
|
|
||||||
LiteLLM supports all models on Databricks
|
LiteLLM supports all models on Databricks
|
||||||
|
|
||||||
|
@ -154,7 +154,205 @@ response = completion(
|
||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
```
|
```
|
||||||
|
|
||||||
## Passings Databricks specific params - 'instruction'
|
|
||||||
|
## Usage - Thinking / `reasoning_content`
|
||||||
|
|
||||||
|
LiteLLM translates OpenAI's `reasoning_effort` to Anthropic's `thinking` parameter. [Code](https://github.com/BerriAI/litellm/blob/23051d89dd3611a81617d84277059cd88b2df511/litellm/llms/anthropic/chat/transformation.py#L298)
|
||||||
|
|
||||||
|
| reasoning_effort | thinking |
|
||||||
|
| ---------------- | -------- |
|
||||||
|
| "low" | "budget_tokens": 1024 |
|
||||||
|
| "medium" | "budget_tokens": 2048 |
|
||||||
|
| "high" | "budget_tokens": 4096 |
|
||||||
|
|
||||||
|
|
||||||
|
Known Limitations:
|
||||||
|
- Support for passing thinking blocks back to Claude [Issue](https://github.com/BerriAI/litellm/issues/9790)
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
# set ENV variables (can also be passed in to .completion() - e.g. `api_base`, `api_key`)
|
||||||
|
os.environ["DATABRICKS_API_KEY"] = "databricks key"
|
||||||
|
os.environ["DATABRICKS_API_BASE"] = "databricks base url"
|
||||||
|
|
||||||
|
resp = completion(
|
||||||
|
model="databricks/databricks-claude-3-7-sonnet",
|
||||||
|
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
||||||
|
reasoning_effort="low",
|
||||||
|
)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- model_name: claude-3-7-sonnet
|
||||||
|
litellm_params:
|
||||||
|
model: databricks/databricks-claude-3-7-sonnet
|
||||||
|
api_key: os.environ/DATABRICKS_API_KEY
|
||||||
|
api_base: os.environ/DATABRICKS_API_BASE
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
|
||||||
|
-d '{
|
||||||
|
"model": "claude-3-7-sonnet",
|
||||||
|
"messages": [{"role": "user", "content": "What is the capital of France?"}],
|
||||||
|
"reasoning_effort": "low"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
**Expected Response**
|
||||||
|
|
||||||
|
```python
|
||||||
|
ModelResponse(
|
||||||
|
id='chatcmpl-c542d76d-f675-4e87-8e5f-05855f5d0f5e',
|
||||||
|
created=1740470510,
|
||||||
|
model='claude-3-7-sonnet-20250219',
|
||||||
|
object='chat.completion',
|
||||||
|
system_fingerprint=None,
|
||||||
|
choices=[
|
||||||
|
Choices(
|
||||||
|
finish_reason='stop',
|
||||||
|
index=0,
|
||||||
|
message=Message(
|
||||||
|
content="The capital of France is Paris.",
|
||||||
|
role='assistant',
|
||||||
|
tool_calls=None,
|
||||||
|
function_call=None,
|
||||||
|
provider_specific_fields={
|
||||||
|
'citations': None,
|
||||||
|
'thinking_blocks': [
|
||||||
|
{
|
||||||
|
'type': 'thinking',
|
||||||
|
'thinking': 'The capital of France is Paris. This is a very straightforward factual question.',
|
||||||
|
'signature': 'EuYBCkQYAiJAy6...'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
thinking_blocks=[
|
||||||
|
{
|
||||||
|
'type': 'thinking',
|
||||||
|
'thinking': 'The capital of France is Paris. This is a very straightforward factual question.',
|
||||||
|
'signature': 'EuYBCkQYAiJAy6AGB...'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
reasoning_content='The capital of France is Paris. This is a very straightforward factual question.'
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=Usage(
|
||||||
|
completion_tokens=68,
|
||||||
|
prompt_tokens=42,
|
||||||
|
total_tokens=110,
|
||||||
|
completion_tokens_details=None,
|
||||||
|
prompt_tokens_details=PromptTokensDetailsWrapper(
|
||||||
|
audio_tokens=None,
|
||||||
|
cached_tokens=0,
|
||||||
|
text_tokens=None,
|
||||||
|
image_tokens=None
|
||||||
|
),
|
||||||
|
cache_creation_input_tokens=0,
|
||||||
|
cache_read_input_tokens=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pass `thinking` to Anthropic models
|
||||||
|
|
||||||
|
You can also pass the `thinking` parameter to Anthropic models.
|
||||||
|
|
||||||
|
|
||||||
|
You can also pass the `thinking` parameter to Anthropic models.
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
# set ENV variables (can also be passed in to .completion() - e.g. `api_base`, `api_key`)
|
||||||
|
os.environ["DATABRICKS_API_KEY"] = "databricks key"
|
||||||
|
os.environ["DATABRICKS_API_BASE"] = "databricks base url"
|
||||||
|
|
||||||
|
response = litellm.completion(
|
||||||
|
model="databricks/databricks-claude-3-7-sonnet",
|
||||||
|
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
||||||
|
thinking={"type": "enabled", "budget_tokens": 1024},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer $LITELLM_KEY" \
|
||||||
|
-d '{
|
||||||
|
"model": "databricks/databricks-claude-3-7-sonnet",
|
||||||
|
"messages": [{"role": "user", "content": "What is the capital of France?"}],
|
||||||
|
"thinking": {"type": "enabled", "budget_tokens": 1024}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Supported Databricks Chat Completion Models
|
||||||
|
|
||||||
|
:::tip
|
||||||
|
|
||||||
|
**We support ALL Databricks models, just set `model=databricks/<any-model-on-databricks>` as a prefix when sending litellm requests**
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
| Model Name | Command |
|
||||||
|
|----------------------------|------------------------------------------------------------------|
|
||||||
|
| databricks/databricks-claude-3-7-sonnet | `completion(model='databricks/databricks/databricks-claude-3-7-sonnet', messages=messages)` |
|
||||||
|
| databricks-meta-llama-3-1-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-70b-instruct', messages=messages)` |
|
||||||
|
| databricks-meta-llama-3-1-405b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-405b-instruct', messages=messages)` |
|
||||||
|
| databricks-dbrx-instruct | `completion(model='databricks/databricks-dbrx-instruct', messages=messages)` |
|
||||||
|
| databricks-meta-llama-3-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-70b-instruct', messages=messages)` |
|
||||||
|
| databricks-llama-2-70b-chat | `completion(model='databricks/databricks-llama-2-70b-chat', messages=messages)` |
|
||||||
|
| databricks-mixtral-8x7b-instruct | `completion(model='databricks/databricks-mixtral-8x7b-instruct', messages=messages)` |
|
||||||
|
| databricks-mpt-30b-instruct | `completion(model='databricks/databricks-mpt-30b-instruct', messages=messages)` |
|
||||||
|
| databricks-mpt-7b-instruct | `completion(model='databricks/databricks-mpt-7b-instruct', messages=messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
## Embedding Models
|
||||||
|
|
||||||
|
### Passing Databricks specific params - 'instruction'
|
||||||
|
|
||||||
For embedding models, databricks lets you pass in an additional param 'instruction'. [Full Spec](https://github.com/BerriAI/litellm/blob/43353c28b341df0d9992b45c6ce464222ebd7984/litellm/llms/databricks.py#L164)
|
For embedding models, databricks lets you pass in an additional param 'instruction'. [Full Spec](https://github.com/BerriAI/litellm/blob/43353c28b341df0d9992b45c6ce464222ebd7984/litellm/llms/databricks.py#L164)
|
||||||
|
|
||||||
|
@ -187,27 +385,6 @@ response = litellm.embedding(
|
||||||
instruction: "Represent this sentence for searching relevant passages:"
|
instruction: "Represent this sentence for searching relevant passages:"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Supported Databricks Chat Completion Models
|
|
||||||
|
|
||||||
:::tip
|
|
||||||
|
|
||||||
**We support ALL Databricks models, just set `model=databricks/<any-model-on-databricks>` as a prefix when sending litellm requests**
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
|
|
||||||
| Model Name | Command |
|
|
||||||
|----------------------------|------------------------------------------------------------------|
|
|
||||||
| databricks-meta-llama-3-1-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-70b-instruct', messages=messages)` |
|
|
||||||
| databricks-meta-llama-3-1-405b-instruct | `completion(model='databricks/databricks-meta-llama-3-1-405b-instruct', messages=messages)` |
|
|
||||||
| databricks-dbrx-instruct | `completion(model='databricks/databricks-dbrx-instruct', messages=messages)` |
|
|
||||||
| databricks-meta-llama-3-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-70b-instruct', messages=messages)` |
|
|
||||||
| databricks-llama-2-70b-chat | `completion(model='databricks/databricks-llama-2-70b-chat', messages=messages)` |
|
|
||||||
| databricks-mixtral-8x7b-instruct | `completion(model='databricks/databricks-mixtral-8x7b-instruct', messages=messages)` |
|
|
||||||
| databricks-mpt-30b-instruct | `completion(model='databricks/databricks-mpt-30b-instruct', messages=messages)` |
|
|
||||||
| databricks-mpt-7b-instruct | `completion(model='databricks/databricks-mpt-7b-instruct', messages=messages)` |
|
|
||||||
|
|
||||||
## Supported Databricks Embedding Models
|
## Supported Databricks Embedding Models
|
||||||
|
|
||||||
:::tip
|
:::tip
|
||||||
|
|
|
@ -887,3 +887,54 @@ response = await client.chat.completions.create(
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
## Image Generation
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="gemini/gemini-2.0-flash-exp-image-generation",
|
||||||
|
messages=[{"role": "user", "content": "Generate an image of a cat"}],
|
||||||
|
modalities=["image", "text"],
|
||||||
|
)
|
||||||
|
assert response.choices[0].message.content is not None # ".."
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gemini-2.0-flash-exp-image-generation
|
||||||
|
litellm_params:
|
||||||
|
model: gemini/gemini-2.0-flash-exp-image-generation
|
||||||
|
api_key: os.environ/GEMINI_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://localhost:4000/v1/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-2.0-flash-exp-image-generation",
|
||||||
|
"messages": [{"role": "user", "content": "Generate an image of a cat"}],
|
||||||
|
"modalities": ["image", "text"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
161
docs/my-website/docs/providers/google_ai_studio/files.md
Normal file
161
docs/my-website/docs/providers/google_ai_studio/files.md
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# [BETA] Google AI Studio (Gemini) Files API
|
||||||
|
|
||||||
|
Use this to upload files to Google AI Studio (Gemini).
|
||||||
|
|
||||||
|
Useful to pass in large media files to Gemini's `/generateContent` endpoint.
|
||||||
|
|
||||||
|
| Action | Supported |
|
||||||
|
|----------|-----------|
|
||||||
|
| `create` | Yes |
|
||||||
|
| `delete` | No |
|
||||||
|
| `retrieve` | No |
|
||||||
|
| `list` | No |
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
from litellm import completion, create_file
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
### UPLOAD FILE ###
|
||||||
|
|
||||||
|
# Fetch the audio file and convert it to a base64 encoded string
|
||||||
|
url = "https://cdn.openai.com/API/docs/audio/alloy.wav"
|
||||||
|
response = requests.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
wav_data = response.content
|
||||||
|
encoded_string = base64.b64encode(wav_data).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
file = create_file(
|
||||||
|
file=wav_data,
|
||||||
|
purpose="user_data",
|
||||||
|
extra_body={"custom_llm_provider": "gemini"},
|
||||||
|
api_key=os.getenv("GEMINI_API_KEY"),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"file: {file}")
|
||||||
|
|
||||||
|
assert file is not None
|
||||||
|
|
||||||
|
|
||||||
|
### GENERATE CONTENT ###
|
||||||
|
completion = completion(
|
||||||
|
model="gemini-2.0-flash",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What is in this recording?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"file": {
|
||||||
|
"file_id": file.id,
|
||||||
|
"filename": "my-test-name",
|
||||||
|
"format": "audio/wav"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(completion.choices[0].message)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: "gemini-2.0-flash"
|
||||||
|
litellm_params:
|
||||||
|
model: gemini/gemini-2.0-flash
|
||||||
|
api_key: os.environ/GEMINI_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it
|
||||||
|
|
||||||
|
```python
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
base_url="http://0.0.0.0:4000",
|
||||||
|
api_key="sk-1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch the audio file and convert it to a base64 encoded string
|
||||||
|
url = "https://cdn.openai.com/API/docs/audio/alloy.wav"
|
||||||
|
response = requests.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
wav_data = response.content
|
||||||
|
encoded_string = base64.b64encode(wav_data).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
file = client.files.create(
|
||||||
|
file=wav_data,
|
||||||
|
purpose="user_data",
|
||||||
|
extra_body={"target_model_names": "gemini-2.0-flash"}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"file: {file}")
|
||||||
|
|
||||||
|
assert file is not None
|
||||||
|
|
||||||
|
completion = client.chat.completions.create(
|
||||||
|
model="gemini-2.0-flash",
|
||||||
|
modalities=["text", "audio"],
|
||||||
|
audio={"voice": "alloy", "format": "wav"},
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What is in this recording?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"file": {
|
||||||
|
"file_id": file.id,
|
||||||
|
"filename": "my-test-name",
|
||||||
|
"format": "audio/wav"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
],
|
||||||
|
extra_body={"drop_params": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(completion.choices[0].message)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
|
@ -2,466 +2,392 @@ import Image from '@theme/IdealImage';
|
||||||
import Tabs from '@theme/Tabs';
|
import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Huggingface
|
# Hugging Face
|
||||||
|
LiteLLM supports running inference across multiple services for models hosted on the Hugging Face Hub.
|
||||||
|
|
||||||
LiteLLM supports the following types of Hugging Face models:
|
- **Serverless Inference Providers** - Hugging Face offers an easy and unified access to serverless AI inference through multiple inference providers, like [Together AI](https://together.ai) and [Sambanova](https://sambanova.ai). This is the fastest way to integrate AI in your products with a maintenance-free and scalable solution. More details in the [Inference Providers documentation](https://huggingface.co/docs/inference-providers/index).
|
||||||
|
- **Dedicated Inference Endpoints** - which is a product to easily deploy models to production. Inference is run by Hugging Face in a dedicated, fully managed infrastructure on a cloud provider of your choice. You can deploy your model on Hugging Face Inference Endpoints by following [these steps](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint).
|
||||||
|
|
||||||
- Serverless Inference API (free) - loaded and ready to use: https://huggingface.co/models?inference=warm&pipeline_tag=text-generation
|
|
||||||
- Dedicated Inference Endpoints (paid) - manual deployment: https://ui.endpoints.huggingface.co/
|
## Supported Models
|
||||||
- All LLMs served via Hugging Face's Inference use [Text-generation-inference](https://huggingface.co/docs/text-generation-inference).
|
|
||||||
|
### Serverless Inference Providers
|
||||||
|
You can check available models for an inference provider by going to [huggingface.co/models](https://huggingface.co/models), clicking the "Other" filter tab, and selecting your desired provider:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
For example, you can find all Fireworks supported models [here](https://huggingface.co/models?inference_provider=fireworks-ai&sort=trending).
|
||||||
|
|
||||||
|
|
||||||
|
### Dedicated Inference Endpoints
|
||||||
|
Refer to the [Inference Endpoints catalog](https://endpoints.huggingface.co/catalog) for a list of available models.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="serverless" label="Serverless Inference Providers">
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|
With a single Hugging Face token, you can access inference through multiple providers. Your calls are routed through Hugging Face and the usage is billed directly to your Hugging Face account at the standard provider API rates.
|
||||||
|
|
||||||
|
Simply set the `HF_TOKEN` environment variable with your Hugging Face token, you can create one here: https://huggingface.co/settings/tokens.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export HF_TOKEN="hf_xxxxxx"
|
||||||
|
```
|
||||||
|
or alternatively, you can pass your Hugging Face token as a parameter:
|
||||||
|
```python
|
||||||
|
completion(..., api_key="hf_xxxxxx")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Getting Started
|
||||||
|
|
||||||
|
To use a Hugging Face model, specify both the provider and model you want to use in the following format:
|
||||||
|
```
|
||||||
|
huggingface/<provider>/<hf_org_or_user>/<hf_model>
|
||||||
|
```
|
||||||
|
Where `<hf_org_or_user>/<hf_model>` is the Hugging Face model ID and `<provider>` is the inference provider.
|
||||||
|
By default, if you don't specify a provider, LiteLLM will use the [HF Inference API](https://huggingface.co/docs/api-inference/en/index).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Run DeepSeek-R1 inference through Together AI
|
||||||
|
completion(model="huggingface/together/deepseek-ai/DeepSeek-R1",...)
|
||||||
|
|
||||||
|
# Run Qwen2.5-72B-Instruct inference through Sambanova
|
||||||
|
completion(model="huggingface/sambanova/Qwen/Qwen2.5-72B-Instruct",...)
|
||||||
|
|
||||||
|
# Run Llama-3.3-70B-Instruct inference through HF Inference API
|
||||||
|
completion(model="huggingface/meta-llama/Llama-3.3-70B-Instruct",...)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
|
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
|
||||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
You need to tell LiteLLM when you're calling Huggingface.
|
### Basic Completion
|
||||||
This is done by adding the "huggingface/" prefix to `model`, example `completion(model="huggingface/<model_name>",...)`.
|
Here's an example of chat completion using the DeepSeek-R1 model through Together AI:
|
||||||
|
|
||||||
<Tabs>
|
|
||||||
<TabItem value="serverless" label="Serverless Inference API">
|
|
||||||
|
|
||||||
By default, LiteLLM will assume a Hugging Face call follows the [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api), which is fully compatible with the OpenAI Chat Completion API.
|
|
||||||
|
|
||||||
<Tabs>
|
|
||||||
<TabItem value="sdk" label="SDK">
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
|
|
||||||
# [OPTIONAL] set env var
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
|
|
||||||
|
|
||||||
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
|
|
||||||
|
|
||||||
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
|
|
||||||
response = completion(
|
response = completion(
|
||||||
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
|
model="huggingface/together/deepseek-ai/DeepSeek-R1",
|
||||||
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "How many r's are in the word 'strawberry'?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming
|
||||||
|
Now, let's see what a streaming request looks like.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/together/deepseek-ai/DeepSeek-R1",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "How many r's are in the word `strawberry`?",
|
||||||
|
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Image Input
|
||||||
|
You can also pass images when the model supports it. Here is an example using [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) model through Sambanova.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
# Set your Hugging Face Token
|
||||||
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
|
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/sambanova/meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
print(response.choices[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Function Calling
|
||||||
|
You can extend the model's capabilities by giving them access to tools. Here is an example with function calling using [Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct) model through Sambanova.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
# Set your Hugging Face Token
|
||||||
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What's the weather like in Boston today?",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/sambanova/meta-llama/Llama-3.3-70B-Instruct",
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto"
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="endpoints" label="Inference Endpoints">
|
||||||
|
|
||||||
|
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
|
||||||
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
### Basic Completion
|
||||||
|
After you have [deployed your Hugging Face Inference Endpoint](https://endpoints.huggingface.co/new) on dedicated infrastructure, you can run inference on it by providing the endpoint base URL in `api_base`, and indicating `huggingface/tgi` as the model name.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/tgi",
|
||||||
|
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||||
|
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/"
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/tgi",
|
||||||
|
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||||
|
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/",
|
||||||
stream=True
|
stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
print(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
<TabItem value="proxy" label="PROXY">
|
|
||||||
|
|
||||||
1. Add models to your config.yaml
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: llama-3.1-8B-instruct
|
|
||||||
litellm_params:
|
|
||||||
model: huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct
|
|
||||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Start the proxy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ litellm --config /path/to/config.yaml --debug
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Test it!
|
|
||||||
|
|
||||||
```shell
|
|
||||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|
||||||
--header 'Authorization: Bearer sk-1234' \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--data '{
|
|
||||||
"model": "llama-3.1-8B-instruct",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I like you!"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
</Tabs>
|
|
||||||
</TabItem>
|
|
||||||
<TabItem value="classification" label="Text Classification">
|
|
||||||
|
|
||||||
Append `text-classification` to the model name
|
|
||||||
|
|
||||||
e.g. `huggingface/text-classification/<model-name>`
|
|
||||||
|
|
||||||
<Tabs>
|
|
||||||
<TabItem value="sdk" label="SDK">
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
from litellm import completion
|
|
||||||
|
|
||||||
# [OPTIONAL] set env var
|
|
||||||
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
|
|
||||||
|
|
||||||
messages = [{ "content": "I like you, I love you!","role": "user"}]
|
|
||||||
|
|
||||||
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
|
|
||||||
response = completion(
|
|
||||||
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
|
|
||||||
messages=messages,
|
|
||||||
api_base="https://my-endpoint.endpoints.huggingface.cloud",
|
|
||||||
)
|
|
||||||
|
|
||||||
print(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
<TabItem value="proxy" label="PROXY">
|
|
||||||
|
|
||||||
1. Add models to your config.yaml
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: bert-classifier
|
|
||||||
litellm_params:
|
|
||||||
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
|
|
||||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
|
||||||
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Start the proxy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ litellm --config /path/to/config.yaml --debug
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Test it!
|
|
||||||
|
|
||||||
```shell
|
|
||||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|
||||||
--header 'Authorization: Bearer sk-1234' \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--data '{
|
|
||||||
"model": "bert-classifier",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I like you!"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
</Tabs>
|
|
||||||
</TabItem>
|
|
||||||
<TabItem value="dedicated" label="Dedicated Inference Endpoints">
|
|
||||||
|
|
||||||
Steps to use
|
|
||||||
* Create your own Hugging Face dedicated endpoint here: https://ui.endpoints.huggingface.co/
|
|
||||||
* Set `api_base` to your deployed api base
|
|
||||||
* Add the `huggingface/` prefix to your model so litellm knows it's a huggingface Deployed Inference Endpoint
|
|
||||||
|
|
||||||
<Tabs>
|
|
||||||
<TabItem value="sdk" label="SDK">
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
from litellm import completion
|
|
||||||
|
|
||||||
os.environ["HUGGINGFACE_API_KEY"] = ""
|
|
||||||
|
|
||||||
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
|
|
||||||
# add the 'huggingface/' prefix to the model to set huggingface as the provider
|
|
||||||
# set api base to your deployed api endpoint from hugging face
|
|
||||||
response = completion(
|
|
||||||
model="huggingface/glaiveai/glaive-coder-7b",
|
|
||||||
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
|
||||||
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
<TabItem value="proxy" label="PROXY">
|
|
||||||
|
|
||||||
1. Add models to your config.yaml
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: glaive-coder
|
|
||||||
litellm_params:
|
|
||||||
model: huggingface/glaiveai/glaive-coder-7b
|
|
||||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
|
||||||
api_base: "https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Start the proxy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ litellm --config /path/to/config.yaml --debug
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Test it!
|
|
||||||
|
|
||||||
```shell
|
|
||||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|
||||||
--header 'Authorization: Bearer sk-1234' \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--data '{
|
|
||||||
"model": "glaive-coder",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I like you!"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
</Tabs>
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
</Tabs>
|
|
||||||
|
|
||||||
## Streaming
|
|
||||||
|
|
||||||
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_HuggingFace.ipynb">
|
|
||||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
|
||||||
</a>
|
|
||||||
|
|
||||||
You need to tell LiteLLM when you're calling Huggingface.
|
|
||||||
This is done by adding the "huggingface/" prefix to `model`, example `completion(model="huggingface/<model_name>",...)`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
from litellm import completion
|
|
||||||
|
|
||||||
# [OPTIONAL] set env var
|
|
||||||
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
|
|
||||||
|
|
||||||
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
|
|
||||||
|
|
||||||
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
|
|
||||||
response = completion(
|
|
||||||
model="huggingface/facebook/blenderbot-400M-distill",
|
|
||||||
messages=messages,
|
|
||||||
api_base="https://my-endpoint.huggingface.cloud",
|
|
||||||
stream=True
|
|
||||||
)
|
|
||||||
|
|
||||||
print(response)
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Image Input
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
|
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/tgi",
|
||||||
|
messages=messages,
|
||||||
|
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/""
|
||||||
|
)
|
||||||
|
print(response.choices[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Function Calling
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
os.environ["HF_TOKEN"] = "hf_xxxxxx"
|
||||||
|
|
||||||
|
functions = [{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The location to get weather for"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/tgi",
|
||||||
|
messages=[{"content": "What's the weather like in San Francisco?", "role": "user"}],
|
||||||
|
api_base="https://my-endpoint.endpoints.huggingface.cloud/v1/",
|
||||||
|
functions=functions
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
## LiteLLM Proxy Server with Hugging Face models
|
||||||
|
You can set up a [LiteLLM Proxy Server](https://docs.litellm.ai/#litellm-proxy-server-llm-gateway) to serve Hugging Face models through any of the supported Inference Providers. Here's how to do it:
|
||||||
|
|
||||||
|
### Step 1. Setup the config file
|
||||||
|
|
||||||
|
In this case, we are configuring a proxy to serve `DeepSeek R1` from Hugging Face, using Together AI as the backend Inference Provider.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: my-r1-model
|
||||||
|
litellm_params:
|
||||||
|
model: huggingface/together/deepseek-ai/DeepSeek-R1
|
||||||
|
api_key: os.environ/HF_TOKEN # ensure you have `HF_TOKEN` in your .env
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2. Start the server
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3. Make a request to the server
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="curl" label="curl">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "my-r1-model",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, how are you?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="python" label="python">
|
||||||
|
|
||||||
|
```python
|
||||||
|
# pip install openai
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
base_url="http://0.0.0.0:4000",
|
||||||
|
api_key="anything",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="my-r1-model",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "Hello, how are you?"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
## Embedding
|
## Embedding
|
||||||
|
|
||||||
LiteLLM supports Hugging Face's [text-embedding-inference](https://github.com/huggingface/text-embeddings-inference) format.
|
LiteLLM supports Hugging Face's [text-embedding-inference](https://github.com/huggingface/text-embeddings-inference) models as well.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from litellm import embedding
|
from litellm import embedding
|
||||||
import os
|
import os
|
||||||
os.environ['HUGGINGFACE_API_KEY'] = ""
|
os.environ['HF_TOKEN'] = "hf_xxxxxx"
|
||||||
response = embedding(
|
response = embedding(
|
||||||
model='huggingface/microsoft/codebert-base',
|
model='huggingface/microsoft/codebert-base',
|
||||||
input=["good morning from litellm"]
|
input=["good morning from litellm"]
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced
|
|
||||||
|
|
||||||
### Setting API KEYS + API BASE
|
|
||||||
|
|
||||||
If required, you can set the api key + api base, set it in your os environment. [Code for how it's sent](https://github.com/BerriAI/litellm/blob/0100ab2382a0e720c7978fbf662cc6e6920e7e03/litellm/llms/huggingface_restapi.py#L25)
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
os.environ["HUGGINGFACE_API_KEY"] = ""
|
|
||||||
os.environ["HUGGINGFACE_API_BASE"] = ""
|
|
||||||
```
|
|
||||||
|
|
||||||
### Viewing Log probs
|
|
||||||
|
|
||||||
#### Using `decoder_input_details` - OpenAI `echo`
|
|
||||||
|
|
||||||
The `echo` param is supported by OpenAI Completions - Use `litellm.text_completion()` for this
|
|
||||||
|
|
||||||
```python
|
|
||||||
from litellm import text_completion
|
|
||||||
response = text_completion(
|
|
||||||
model="huggingface/bigcode/starcoder",
|
|
||||||
prompt="good morning",
|
|
||||||
max_tokens=10, logprobs=10,
|
|
||||||
echo=True
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Output
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"id": "chatcmpl-3fc71792-c442-4ba1-a611-19dd0ac371ad",
|
|
||||||
"object": "text_completion",
|
|
||||||
"created": 1698801125.936519,
|
|
||||||
"model": "bigcode/starcoder",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"text": ", I'm going to make you a sand",
|
|
||||||
"index": 0,
|
|
||||||
"logprobs": {
|
|
||||||
"tokens": [
|
|
||||||
"good",
|
|
||||||
" morning",
|
|
||||||
",",
|
|
||||||
" I",
|
|
||||||
"'m",
|
|
||||||
" going",
|
|
||||||
" to",
|
|
||||||
" make",
|
|
||||||
" you",
|
|
||||||
" a",
|
|
||||||
" s",
|
|
||||||
"and"
|
|
||||||
],
|
|
||||||
"token_logprobs": [
|
|
||||||
"None",
|
|
||||||
-14.96875,
|
|
||||||
-2.2285156,
|
|
||||||
-2.734375,
|
|
||||||
-2.0957031,
|
|
||||||
-2.0917969,
|
|
||||||
-0.09429932,
|
|
||||||
-3.1132812,
|
|
||||||
-1.3203125,
|
|
||||||
-1.2304688,
|
|
||||||
-1.6201172,
|
|
||||||
-0.010292053
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"finish_reason": "length"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"completion_tokens": 9,
|
|
||||||
"prompt_tokens": 2,
|
|
||||||
"total_tokens": 11
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Models with Prompt Formatting
|
|
||||||
|
|
||||||
For models with special prompt templates (e.g. Llama2), we format the prompt to fit their template.
|
|
||||||
|
|
||||||
#### Models with natively Supported Prompt Templates
|
|
||||||
|
|
||||||
| Model Name | Works for Models | Function Call | Required OS Variables |
|
|
||||||
| ------------------------------------ | ---------------------------------- | ----------------------------------------------------------------------------------------------------------------------- | ----------------------------------- |
|
|
||||||
| mistralai/Mistral-7B-Instruct-v0.1 | mistralai/Mistral-7B-Instruct-v0.1 | `completion(model='huggingface/mistralai/Mistral-7B-Instruct-v0.1', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
|
|
||||||
| meta-llama/Llama-2-7b-chat | All meta-llama llama2 chat models | `completion(model='huggingface/meta-llama/Llama-2-7b', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
|
|
||||||
| tiiuae/falcon-7b-instruct | All falcon instruct models | `completion(model='huggingface/tiiuae/falcon-7b-instruct', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
|
|
||||||
| mosaicml/mpt-7b-chat | All mpt chat models | `completion(model='huggingface/mosaicml/mpt-7b-chat', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
|
|
||||||
| codellama/CodeLlama-34b-Instruct-hf | All codellama instruct models | `completion(model='huggingface/codellama/CodeLlama-34b-Instruct-hf', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
|
|
||||||
| WizardLM/WizardCoder-Python-34B-V1.0 | All wizardcoder models | `completion(model='huggingface/WizardLM/WizardCoder-Python-34B-V1.0', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
|
|
||||||
| Phind/Phind-CodeLlama-34B-v2 | All phind-codellama models | `completion(model='huggingface/Phind/Phind-CodeLlama-34B-v2', messages=messages, api_base="your_api_endpoint")` | `os.environ['HUGGINGFACE_API_KEY']` |
|
|
||||||
|
|
||||||
**What if we don't support a model you need?**
|
|
||||||
You can also specify you're own custom prompt formatting, in case we don't have your model covered yet.
|
|
||||||
|
|
||||||
**Does this mean you have to specify a prompt for all models?**
|
|
||||||
No. By default we'll concatenate your message content to make a prompt.
|
|
||||||
|
|
||||||
**Default Prompt Template**
|
|
||||||
|
|
||||||
```python
|
|
||||||
def default_pt(messages):
|
|
||||||
return " ".join(message["content"] for message in messages)
|
|
||||||
```
|
|
||||||
|
|
||||||
[Code for how prompt formats work in LiteLLM](https://github.com/BerriAI/litellm/blob/main/litellm/llms/prompt_templates/factory.py)
|
|
||||||
|
|
||||||
#### Custom prompt templates
|
|
||||||
|
|
||||||
```python
|
|
||||||
import litellm
|
|
||||||
|
|
||||||
# Create your own custom prompt template works
|
|
||||||
litellm.register_prompt_template(
|
|
||||||
model="togethercomputer/LLaMA-2-7B-32K",
|
|
||||||
roles={
|
|
||||||
"system": {
|
|
||||||
"pre_message": "[INST] <<SYS>>\n",
|
|
||||||
"post_message": "\n<</SYS>>\n [/INST]\n"
|
|
||||||
},
|
|
||||||
"user": {
|
|
||||||
"pre_message": "[INST] ",
|
|
||||||
"post_message": " [/INST]\n"
|
|
||||||
},
|
|
||||||
"assistant": {
|
|
||||||
"post_message": "\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_huggingface_custom_model():
|
|
||||||
model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
|
|
||||||
response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud")
|
|
||||||
print(response['choices'][0]['message']['content'])
|
|
||||||
return response
|
|
||||||
|
|
||||||
test_huggingface_custom_model()
|
|
||||||
```
|
|
||||||
|
|
||||||
[Implementation Code](https://github.com/BerriAI/litellm/blob/c0b3da2c14c791a0b755f0b1e5a9ef065951ecbf/litellm/llms/huggingface_restapi.py#L52)
|
|
||||||
|
|
||||||
### Deploying a model on huggingface
|
|
||||||
|
|
||||||
You can use any chat/text model from Hugging Face with the following steps:
|
|
||||||
|
|
||||||
- Copy your model id/url from Huggingface Inference Endpoints
|
|
||||||
- [ ] Go to https://ui.endpoints.huggingface.co/
|
|
||||||
- [ ] Copy the url of the specific model you'd like to use
|
|
||||||
<Image img={require('../../img/hf_inference_endpoint.png')} alt="HF_Dashboard" style={{ maxWidth: '50%', height: 'auto' }}/>
|
|
||||||
- Set it as your model name
|
|
||||||
- Set your HUGGINGFACE_API_KEY as an environment variable
|
|
||||||
|
|
||||||
Need help deploying a model on huggingface? [Check out this guide.](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint)
|
|
||||||
|
|
||||||
# output
|
|
||||||
|
|
||||||
Same as the OpenAI format, but also includes logprobs. [See the code](https://github.com/BerriAI/litellm/blob/b4b2dbf005142e0a483d46a07a88a19814899403/litellm/llms/huggingface_restapi.py#L115)
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"finish_reason": "stop",
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"content": "\ud83d\ude31\n\nComment: @SarahSzabo I'm",
|
|
||||||
"role": "assistant",
|
|
||||||
"logprobs": -22.697942825499993
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"created": 1693436637.38206,
|
|
||||||
"model": "https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud",
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 14,
|
|
||||||
"completion_tokens": 11,
|
|
||||||
"total_tokens": 25
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
# FAQ
|
# FAQ
|
||||||
|
|
||||||
**Does this support stop sequences?**
|
**How does billing work with Hugging Face Inference Providers?**
|
||||||
|
|
||||||
Yes, we support stop sequences - and you can pass as many as allowed by Hugging Face (or any provider!)
|
> Billing is centralized on your Hugging Face account, no matter which providers you are using. You are billed the standard provider API rates with no additional markup - Hugging Face simply passes through the provider costs. Note that [Hugging Face PRO](https://huggingface.co/subscribe/pro) users get $2 worth of Inference credits every month that can be used across providers.
|
||||||
|
|
||||||
**How do you deal with repetition penalty?**
|
**Do I need to create an account for each Inference Provider?**
|
||||||
|
|
||||||
We map the presence penalty parameter in openai to the repetition penalty parameter on Hugging Face. [See code](https://github.com/BerriAI/litellm/blob/b4b2dbf005142e0a483d46a07a88a19814899403/litellm/utils.py#L757).
|
> No, you don't need to create separate accounts. All requests are routed through Hugging Face, so you only need your HF token. This allows you to easily benchmark different providers and choose the one that best fits your needs.
|
||||||
|
|
||||||
|
**Will more inference providers be supported by Hugging Face in the future?**
|
||||||
|
|
||||||
|
> Yes! New inference providers (and models) are being added gradually.
|
||||||
|
|
||||||
We welcome any suggestions for improving our Hugging Face integration - Create an [issue](https://github.com/BerriAI/litellm/issues/new/choose)/[Join the Discord](https://discord.com/invite/wuPM9dRgDw)!
|
We welcome any suggestions for improving our Hugging Face integration - Create an [issue](https://github.com/BerriAI/litellm/issues/new/choose)/[Join the Discord](https://discord.com/invite/wuPM9dRgDw)!
|
|
@ -156,7 +156,7 @@ PROXY_LOGOUT_URL="https://www.google.com"
|
||||||
|
|
||||||
Set this in your .env (so the proxy can set the correct redirect url)
|
Set this in your .env (so the proxy can set the correct redirect url)
|
||||||
```shell
|
```shell
|
||||||
PROXY_BASE_URL=https://litellm-api.up.railway.app/
|
PROXY_BASE_URL=https://litellm-api.up.railway.app
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Step 4. Test flow
|
#### Step 4. Test flow
|
||||||
|
|
|
@ -406,6 +406,7 @@ router_settings:
|
||||||
| HELICONE_API_KEY | API key for Helicone service
|
| HELICONE_API_KEY | API key for Helicone service
|
||||||
| HOSTNAME | Hostname for the server, this will be [emitted to `datadog` logs](https://docs.litellm.ai/docs/proxy/logging#datadog)
|
| HOSTNAME | Hostname for the server, this will be [emitted to `datadog` logs](https://docs.litellm.ai/docs/proxy/logging#datadog)
|
||||||
| HUGGINGFACE_API_BASE | Base URL for Hugging Face API
|
| HUGGINGFACE_API_BASE | Base URL for Hugging Face API
|
||||||
|
| HUGGINGFACE_API_KEY | API key for Hugging Face API
|
||||||
| IAM_TOKEN_DB_AUTH | IAM token for database authentication
|
| IAM_TOKEN_DB_AUTH | IAM token for database authentication
|
||||||
| JSON_LOGS | Enable JSON formatted logging
|
| JSON_LOGS | Enable JSON formatted logging
|
||||||
| JWT_AUDIENCE | Expected audience for JWT tokens
|
| JWT_AUDIENCE | Expected audience for JWT tokens
|
||||||
|
|
|
@ -6,6 +6,8 @@ import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
Track spend for keys, users, and teams across 100+ LLMs.
|
Track spend for keys, users, and teams across 100+ LLMs.
|
||||||
|
|
||||||
|
LiteLLM automatically tracks spend for all known models. See our [model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||||
|
|
||||||
### How to Track Spend with LiteLLM
|
### How to Track Spend with LiteLLM
|
||||||
|
|
||||||
**Step 1**
|
**Step 1**
|
||||||
|
@ -35,10 +37,10 @@ response = client.chat.completions.create(
|
||||||
"content": "this is a test request, write a short poem"
|
"content": "this is a test request, write a short poem"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
user="palantir",
|
user="palantir", # OPTIONAL: pass user to track spend by user
|
||||||
extra_body={
|
extra_body={
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"]
|
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] # ENTERPRISE: pass tags to track spend by tags
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -63,9 +65,9 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
"content": "what llm are you"
|
"content": "what llm are you"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"user": "palantir",
|
"user": "palantir", # OPTIONAL: pass user to track spend by user
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"]
|
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] # ENTERPRISE: pass tags to track spend by tags
|
||||||
}
|
}
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
@ -90,7 +92,7 @@ chat = ChatOpenAI(
|
||||||
user="palantir",
|
user="palantir",
|
||||||
extra_body={
|
extra_body={
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"]
|
"tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"] # ENTERPRISE: pass tags to track spend by tags
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -150,8 +152,134 @@ Navigate to the Usage Tab on the LiteLLM UI (found on https://your-proxy-endpoin
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
## ✨ (Enterprise) API Endpoints to get Spend
|
### Allowing Non-Proxy Admins to access `/spend` endpoints
|
||||||
### Getting Spend Reports - To Charge Other Teams, Customers, Users
|
|
||||||
|
Use this when you want non-proxy admins to access `/spend` endpoints
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Schedule a [meeting with us to get your Enterprise License](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
##### Create Key
|
||||||
|
Create Key with with `permissions={"get_spend_routes": true}`
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"permissions": {"get_spend_routes": true}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Use generated key on `/spend` endpoints
|
||||||
|
|
||||||
|
Access spend Routes with newly generate keys
|
||||||
|
```shell
|
||||||
|
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30' \
|
||||||
|
-H 'Authorization: Bearer sk-H16BKvrSNConSsBYLGc_7A'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### Reset Team, API Key Spend - MASTER KEY ONLY
|
||||||
|
|
||||||
|
Use `/global/spend/reset` if you want to:
|
||||||
|
- Reset the Spend for all API Keys, Teams. The `spend` for ALL Teams and Keys in `LiteLLM_TeamTable` and `LiteLLM_VerificationToken` will be set to `spend=0`
|
||||||
|
|
||||||
|
- LiteLLM will maintain all the logs in `LiteLLMSpendLogs` for Auditing Purposes
|
||||||
|
|
||||||
|
##### Request
|
||||||
|
Only the `LITELLM_MASTER_KEY` you set can access this route
|
||||||
|
```shell
|
||||||
|
curl -X POST \
|
||||||
|
'http://localhost:4000/global/spend/reset' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json'
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Expected Responses
|
||||||
|
|
||||||
|
```shell
|
||||||
|
{"message":"Spend for all API Keys and Teams reset successfully","status":"success"}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Set 'base_model' for Cost Tracking (e.g. Azure deployments)
|
||||||
|
|
||||||
|
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
||||||
|
|
||||||
|
**Solution** ✅ : Set `base_model` on your config so litellm uses the correct model for calculating azure cost
|
||||||
|
|
||||||
|
Get the base model name from [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||||
|
|
||||||
|
Example config with `base_model`
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: azure-gpt-3.5
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
api_base: os.environ/AZURE_API_BASE
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_version: "2023-07-01-preview"
|
||||||
|
model_info:
|
||||||
|
base_model: azure/gpt-4-1106-preview
|
||||||
|
```
|
||||||
|
|
||||||
|
## Daily Spend Breakdown API
|
||||||
|
|
||||||
|
Retrieve granular daily usage data for a user (by model, provider, and API key) with a single endpoint.
|
||||||
|
|
||||||
|
Example Request:
|
||||||
|
|
||||||
|
```shell title="Daily Spend Breakdown API" showLineNumbers
|
||||||
|
curl -L -X GET 'http://localhost:4000/user/daily/activity?start_date=2025-03-20&end_date=2025-03-27' \
|
||||||
|
-H 'Authorization: Bearer sk-...'
|
||||||
|
```
|
||||||
|
|
||||||
|
```json title="Daily Spend Breakdown API Response" showLineNumbers
|
||||||
|
{
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"date": "2025-03-27",
|
||||||
|
"metrics": {
|
||||||
|
"spend": 0.0177072,
|
||||||
|
"prompt_tokens": 111,
|
||||||
|
"completion_tokens": 1711,
|
||||||
|
"total_tokens": 1822,
|
||||||
|
"api_requests": 11
|
||||||
|
},
|
||||||
|
"breakdown": {
|
||||||
|
"models": {
|
||||||
|
"gpt-4o-mini": {
|
||||||
|
"spend": 1.095e-05,
|
||||||
|
"prompt_tokens": 37,
|
||||||
|
"completion_tokens": 9,
|
||||||
|
"total_tokens": 46,
|
||||||
|
"api_requests": 1
|
||||||
|
},
|
||||||
|
"providers": { "openai": { ... }, "azure_ai": { ... } },
|
||||||
|
"api_keys": { "3126b6eaf1...": { ... } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"total_spend": 0.7274667,
|
||||||
|
"total_prompt_tokens": 280990,
|
||||||
|
"total_completion_tokens": 376674,
|
||||||
|
"total_api_requests": 14
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### API Reference
|
||||||
|
|
||||||
|
See our [Swagger API](https://litellm-api.up.railway.app/#/Budget%20%26%20Spend%20Tracking/get_user_daily_activity_user_daily_activity_get) for more details on the `/user/daily/activity` endpoint
|
||||||
|
|
||||||
|
## ✨ (Enterprise) Generate Spend Reports
|
||||||
|
|
||||||
|
Use this to charge other teams, customers, users
|
||||||
|
|
||||||
Use the `/global/spend/report` endpoint to get spend reports
|
Use the `/global/spend/report` endpoint to get spend reports
|
||||||
|
|
||||||
|
@ -470,105 +598,6 @@ curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
### Allowing Non-Proxy Admins to access `/spend` endpoints
|
|
||||||
|
|
||||||
Use this when you want non-proxy admins to access `/spend` endpoints
|
|
||||||
|
|
||||||
:::info
|
|
||||||
|
|
||||||
Schedule a [meeting with us to get your Enterprise License](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
##### Create Key
|
|
||||||
Create Key with with `permissions={"get_spend_routes": true}`
|
|
||||||
```shell
|
|
||||||
curl --location 'http://0.0.0.0:4000/key/generate' \
|
|
||||||
--header 'Authorization: Bearer sk-1234' \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--data '{
|
|
||||||
"permissions": {"get_spend_routes": true}
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
##### Use generated key on `/spend` endpoints
|
|
||||||
|
|
||||||
Access spend Routes with newly generate keys
|
|
||||||
```shell
|
|
||||||
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30' \
|
|
||||||
-H 'Authorization: Bearer sk-H16BKvrSNConSsBYLGc_7A'
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#### Reset Team, API Key Spend - MASTER KEY ONLY
|
|
||||||
|
|
||||||
Use `/global/spend/reset` if you want to:
|
|
||||||
- Reset the Spend for all API Keys, Teams. The `spend` for ALL Teams and Keys in `LiteLLM_TeamTable` and `LiteLLM_VerificationToken` will be set to `spend=0`
|
|
||||||
|
|
||||||
- LiteLLM will maintain all the logs in `LiteLLMSpendLogs` for Auditing Purposes
|
|
||||||
|
|
||||||
##### Request
|
|
||||||
Only the `LITELLM_MASTER_KEY` you set can access this route
|
|
||||||
```shell
|
|
||||||
curl -X POST \
|
|
||||||
'http://localhost:4000/global/spend/reset' \
|
|
||||||
-H 'Authorization: Bearer sk-1234' \
|
|
||||||
-H 'Content-Type: application/json'
|
|
||||||
```
|
|
||||||
|
|
||||||
##### Expected Responses
|
|
||||||
|
|
||||||
```shell
|
|
||||||
{"message":"Spend for all API Keys and Teams reset successfully","status":"success"}
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Spend Tracking for Azure OpenAI Models
|
|
||||||
|
|
||||||
Set base model for cost tracking azure image-gen call
|
|
||||||
|
|
||||||
#### Image Generation
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: dall-e-3
|
|
||||||
litellm_params:
|
|
||||||
model: azure/dall-e-3-test
|
|
||||||
api_version: 2023-06-01-preview
|
|
||||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
|
||||||
api_key: os.environ/AZURE_API_KEY
|
|
||||||
base_model: dall-e-3 # 👈 set dall-e-3 as base model
|
|
||||||
model_info:
|
|
||||||
mode: image_generation
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Chat Completions / Embeddings
|
|
||||||
|
|
||||||
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
|
||||||
|
|
||||||
**Solution** ✅ : Set `base_model` on your config so litellm uses the correct model for calculating azure cost
|
|
||||||
|
|
||||||
Get the base model name from [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
|
||||||
|
|
||||||
Example config with `base_model`
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: azure-gpt-3.5
|
|
||||||
litellm_params:
|
|
||||||
model: azure/chatgpt-v-2
|
|
||||||
api_base: os.environ/AZURE_API_BASE
|
|
||||||
api_key: os.environ/AZURE_API_KEY
|
|
||||||
api_version: "2023-07-01-preview"
|
|
||||||
model_info:
|
|
||||||
base_model: azure/gpt-4-1106-preview
|
|
||||||
```
|
|
||||||
|
|
||||||
## Custom Input/Output Pricing
|
|
||||||
|
|
||||||
👉 Head to [Custom Input/Output Pricing](https://docs.litellm.ai/docs/proxy/custom_pricing) to setup custom pricing or your models
|
|
||||||
|
|
||||||
## ✨ Custom Spend Log metadata
|
## ✨ Custom Spend Log metadata
|
||||||
|
|
||||||
|
@ -588,3 +617,4 @@ Logging specific key,value pairs in spend logs metadata is an enterprise feature
|
||||||
Tracking spend with Custom tags is an enterprise feature. [See here](./enterprise.md#tracking-spend-for-custom-tags)
|
Tracking spend with Custom tags is an enterprise feature. [See here](./enterprise.md#tracking-spend-for-custom-tags)
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
|
86
docs/my-website/docs/proxy/db_deadlocks.md
Normal file
86
docs/my-website/docs/proxy/db_deadlocks.md
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# High Availability Setup (Resolve DB Deadlocks)
|
||||||
|
|
||||||
|
Resolve any Database Deadlocks you see in high traffic by using this setup
|
||||||
|
|
||||||
|
## What causes the problem?
|
||||||
|
|
||||||
|
LiteLLM writes `UPDATE` and `UPSERT` queries to the DB. When using 10+ instances of LiteLLM, these queries can cause deadlocks since each instance could simultaneously attempt to update the same `user_id`, `team_id`, `key` etc.
|
||||||
|
|
||||||
|
## How the high availability setup fixes the problem
|
||||||
|
- All instances will write to a Redis queue instead of the DB.
|
||||||
|
- A single instance will acquire a lock on the DB and flush the redis queue to the DB.
|
||||||
|
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
### Stage 1. Each instance writes updates to redis
|
||||||
|
|
||||||
|
Each instance will accumlate the spend updates for a key, user, team, etc and write the updates to a redis queue.
|
||||||
|
|
||||||
|
<Image img={require('../../img/deadlock_fix_1.png')} style={{ width: '900px', height: 'auto' }} />
|
||||||
|
<p style={{textAlign: 'left', color: '#666'}}>
|
||||||
|
Each instance writes updates to redis
|
||||||
|
</p>
|
||||||
|
|
||||||
|
|
||||||
|
### Stage 2. A single instance flushes the redis queue to the DB
|
||||||
|
|
||||||
|
A single instance will acquire a lock on the DB and flush all elements in the redis queue to the DB.
|
||||||
|
|
||||||
|
- 1 instance will attempt to acquire the lock for the DB update job
|
||||||
|
- The status of the lock is stored in redis
|
||||||
|
- If the instance acquires the lock to write to DB
|
||||||
|
- It will read all updates from redis
|
||||||
|
- Aggregate all updates into 1 transaction
|
||||||
|
- Write updates to DB
|
||||||
|
- Release the lock
|
||||||
|
- Note: Only 1 instance can acquire the lock at a time, this limits the number of instances that can write to the DB at once
|
||||||
|
|
||||||
|
|
||||||
|
<Image img={require('../../img/deadlock_fix_2.png')} style={{ width: '900px', height: 'auto' }} />
|
||||||
|
<p style={{textAlign: 'left', color: '#666'}}>
|
||||||
|
A single instance flushes the redis queue to the DB
|
||||||
|
</p>
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Required components
|
||||||
|
|
||||||
|
- Redis
|
||||||
|
- Postgres
|
||||||
|
|
||||||
|
### Setup on LiteLLM config
|
||||||
|
|
||||||
|
You can enable using the redis buffer by setting `use_redis_transaction_buffer: true` in the `general_settings` section of your `proxy_config.yaml` file.
|
||||||
|
|
||||||
|
Note: This setup requires litellm to be connected to a redis instance.
|
||||||
|
|
||||||
|
```yaml showLineNumbers title="litellm proxy_config.yaml"
|
||||||
|
general_settings:
|
||||||
|
use_redis_transaction_buffer: true
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
cache: True
|
||||||
|
cache_params:
|
||||||
|
type: redis
|
||||||
|
supported_call_types: [] # Optional: Set cache for proxy, but not on the actual llm api call
|
||||||
|
```
|
||||||
|
|
||||||
|
## Monitoring
|
||||||
|
|
||||||
|
LiteLLM emits the following prometheus metrics to monitor the health/status of the in memory buffer and redis buffer.
|
||||||
|
|
||||||
|
|
||||||
|
| Metric Name | Description | Storage Type |
|
||||||
|
|-----------------------------------------------------|-----------------------------------------------------------------------------|--------------|
|
||||||
|
| `litellm_pod_lock_manager_size` | Indicates which pod has the lock to write updates to the database. | Redis |
|
||||||
|
| `litellm_in_memory_daily_spend_update_queue_size` | Number of items in the in-memory daily spend update queue. These are the aggregate spend logs for each user. | In-Memory |
|
||||||
|
| `litellm_redis_daily_spend_update_queue_size` | Number of items in the Redis daily spend update queue. These are the aggregate spend logs for each user. | Redis |
|
||||||
|
| `litellm_in_memory_spend_update_queue_size` | In-memory aggregate spend values for keys, users, teams, team members, etc.| In-Memory |
|
||||||
|
| `litellm_redis_spend_update_queue_size` | Redis aggregate spend values for keys, users, teams, etc. | Redis |
|
||||||
|
|
|
@ -140,7 +140,7 @@ The above request should not be blocked, and you should receive a regular LLM re
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
# Advanced
|
## Advanced
|
||||||
|
|
||||||
Aim Guard provides user-specific Guardrail policies, enabling you to apply tailored policies to individual users.
|
Aim Guard provides user-specific Guardrail policies, enabling you to apply tailored policies to individual users.
|
||||||
To utilize this feature, include the end-user's email in the request payload by setting the `x-aim-user-email` header of your request.
|
To utilize this feature, include the end-user's email in the request payload by setting the `x-aim-user-email` header of your request.
|
||||||
|
|
|
@ -177,6 +177,50 @@ export LITELLM_SALT_KEY="sk-1234"
|
||||||
|
|
||||||
[**See Code**](https://github.com/BerriAI/litellm/blob/036a6821d588bd36d170713dcf5a72791a694178/litellm/proxy/common_utils/encrypt_decrypt_utils.py#L15)
|
[**See Code**](https://github.com/BerriAI/litellm/blob/036a6821d588bd36d170713dcf5a72791a694178/litellm/proxy/common_utils/encrypt_decrypt_utils.py#L15)
|
||||||
|
|
||||||
|
|
||||||
|
## 9. Use `prisma migrate deploy`
|
||||||
|
|
||||||
|
Use this to handle db migrations across LiteLLM versions in production
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="env" label="ENV">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
USE_PRISMA_MIGRATE="True"
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="cli" label="CLI">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --use_prisma_migrate
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
Benefits:
|
||||||
|
|
||||||
|
The migrate deploy command:
|
||||||
|
|
||||||
|
- **Does not** issue a warning if an already applied migration is missing from migration history
|
||||||
|
- **Does not** detect drift (production database schema differs from migration history end state - for example, due to a hotfix)
|
||||||
|
- **Does not** reset the database or generate artifacts (such as Prisma Client)
|
||||||
|
- **Does not** rely on a shadow database
|
||||||
|
|
||||||
|
|
||||||
|
### How does LiteLLM handle DB migrations in production?
|
||||||
|
|
||||||
|
1. A new migration file is written to our `litellm-proxy-extras` package. [See all](https://github.com/BerriAI/litellm/tree/main/litellm-proxy-extras/litellm_proxy_extras/migrations)
|
||||||
|
|
||||||
|
2. The core litellm pip package is bumped to point to the new `litellm-proxy-extras` package. This ensures, older versions of LiteLLM will continue to use the old migrations. [See code](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/pyproject.toml#L58)
|
||||||
|
|
||||||
|
3. When you upgrade to a new version of LiteLLM, the migration file is applied to the database. [See code](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/litellm-proxy-extras/litellm_proxy_extras/utils.py#L42)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Extras
|
## Extras
|
||||||
### Expected Performance in Production
|
### Expected Performance in Production
|
||||||
|
|
||||||
|
|
|
@ -242,6 +242,19 @@ litellm_settings:
|
||||||
| `litellm_redis_fails` | Number of failed redis calls |
|
| `litellm_redis_fails` | Number of failed redis calls |
|
||||||
| `litellm_self_latency` | Histogram latency for successful litellm api call |
|
| `litellm_self_latency` | Histogram latency for successful litellm api call |
|
||||||
|
|
||||||
|
#### DB Transaction Queue Health Metrics
|
||||||
|
|
||||||
|
Use these metrics to monitor the health of the DB Transaction Queue. Eg. Monitoring the size of the in-memory and redis buffers.
|
||||||
|
|
||||||
|
| Metric Name | Description | Storage Type |
|
||||||
|
|-----------------------------------------------------|-----------------------------------------------------------------------------|--------------|
|
||||||
|
| `litellm_pod_lock_manager_size` | Indicates which pod has the lock to write updates to the database. | Redis |
|
||||||
|
| `litellm_in_memory_daily_spend_update_queue_size` | Number of items in the in-memory daily spend update queue. These are the aggregate spend logs for each user. | In-Memory |
|
||||||
|
| `litellm_redis_daily_spend_update_queue_size` | Number of items in the Redis daily spend update queue. These are the aggregate spend logs for each user. | Redis |
|
||||||
|
| `litellm_in_memory_spend_update_queue_size` | In-memory aggregate spend values for keys, users, teams, team members, etc.| In-Memory |
|
||||||
|
| `litellm_redis_spend_update_queue_size` | Redis aggregate spend values for keys, users, teams, etc. | Redis |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## **🔥 LiteLLM Maintained Grafana Dashboards **
|
## **🔥 LiteLLM Maintained Grafana Dashboards **
|
||||||
|
|
||||||
|
@ -268,6 +281,17 @@ Here is a screenshot of the metrics you can monitor with the LiteLLM Grafana Das
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Add authentication on /metrics endpoint
|
||||||
|
|
||||||
|
**By default /metrics endpoint is unauthenticated.**
|
||||||
|
|
||||||
|
You can opt into running litellm authentication on the /metrics endpoint by setting the following on the config
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
litellm_settings:
|
||||||
|
require_auth_for_metrics_endpoint: true
|
||||||
|
```
|
||||||
|
|
||||||
## FAQ
|
## FAQ
|
||||||
|
|
||||||
### What are `_created` vs. `_total` metrics?
|
### What are `_created` vs. `_total` metrics?
|
||||||
|
|
BIN
docs/my-website/img/deadlock_fix_1.png
Normal file
BIN
docs/my-website/img/deadlock_fix_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 60 KiB |
BIN
docs/my-website/img/deadlock_fix_2.png
Normal file
BIN
docs/my-website/img/deadlock_fix_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 70 KiB |
BIN
docs/my-website/img/hf_filter_inference_providers.png
Normal file
BIN
docs/my-website/img/hf_filter_inference_providers.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 120 KiB |
BIN
docs/my-website/img/prevent_deadlocks.jpg
Normal file
BIN
docs/my-website/img/prevent_deadlocks.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 325 KiB |
BIN
docs/my-website/img/release_notes/new_activity_tab.png
Normal file
BIN
docs/my-website/img/release_notes/new_activity_tab.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 326 KiB |
BIN
docs/my-website/img/release_notes/spend_by_model.jpg
Normal file
BIN
docs/my-website/img/release_notes/spend_by_model.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 488 KiB |
7
docs/my-website/package-lock.json
generated
7
docs/my-website/package-lock.json
generated
|
@ -12559,9 +12559,10 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/image-size": {
|
"node_modules/image-size": {
|
||||||
"version": "1.1.1",
|
"version": "1.2.1",
|
||||||
"resolved": "https://registry.npmjs.org/image-size/-/image-size-1.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz",
|
||||||
"integrity": "sha512-541xKlUw6jr/6gGuk92F+mYM5zaFAc5ahphvkqvNe2bQ6gVBkd6bfrmVJ2t4KDAfikAYZyIqTnktX3i6/aQDrQ==",
|
"integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==",
|
||||||
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"queue": "6.0.2"
|
"queue": "6.0.2"
|
||||||
},
|
},
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
|
@ -6,7 +6,7 @@ authors:
|
||||||
- name: Krrish Dholakia
|
- name: Krrish Dholakia
|
||||||
title: CEO, LiteLLM
|
title: CEO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/krish-d/
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
- name: Ishaan Jaffer
|
- name: Ishaan Jaffer
|
||||||
title: CTO, LiteLLM
|
title: CTO, LiteLLM
|
||||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
|
176
docs/my-website/release_notes/v1.65.4-stable/index.md
Normal file
176
docs/my-website/release_notes/v1.65.4-stable/index.md
Normal file
|
@ -0,0 +1,176 @@
|
||||||
|
---
|
||||||
|
title: v1.65.4-stable
|
||||||
|
slug: v1.65.4-stable
|
||||||
|
date: 2025-04-05T10:00:00
|
||||||
|
authors:
|
||||||
|
- name: Krrish Dholakia
|
||||||
|
title: CEO, LiteLLM
|
||||||
|
url: https://www.linkedin.com/in/krish-d/
|
||||||
|
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1749686400&v=beta&t=Hkl3U8Ps0VtvNxX0BNNq24b4dtX5wQaPFp6oiKCIHD8
|
||||||
|
- name: Ishaan Jaffer
|
||||||
|
title: CTO, LiteLLM
|
||||||
|
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||||
|
image_url: https://pbs.twimg.com/profile_images/1613813310264340481/lz54oEiB_400x400.jpg
|
||||||
|
|
||||||
|
tags: []
|
||||||
|
hide_table_of_contents: false
|
||||||
|
---
|
||||||
|
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
## Deploy this version
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="docker" label="Docker">
|
||||||
|
|
||||||
|
``` showLineNumbers title="docker run litellm"
|
||||||
|
docker run
|
||||||
|
-e STORE_MODEL_IN_DB=True
|
||||||
|
-p 4000:4000
|
||||||
|
ghcr.io/berriai/litellm:main-v1.65.4-stable
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="pip" label="Pip">
|
||||||
|
|
||||||
|
``` showLineNumbers title="pip install litellm"
|
||||||
|
pip install litellm==1.65.4.post1
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
v1.65.4-stable is live. Here are the improvements since v1.65.0-stable.
|
||||||
|
|
||||||
|
## Key Highlights
|
||||||
|
- **Preventing DB Deadlocks**: Fixes a high-traffic issue when multiple instances were writing to the DB at the same time.
|
||||||
|
- **New Usage Tab**: Enables viewing spend by model and customizing date range
|
||||||
|
|
||||||
|
Let's dive in.
|
||||||
|
|
||||||
|
### Preventing DB Deadlocks
|
||||||
|
|
||||||
|
<Image img={require('../../img/prevent_deadlocks.jpg')} />
|
||||||
|
|
||||||
|
This release fixes the DB deadlocking issue that users faced in high traffic (10K+ RPS). This is great because it enables user/key/team spend tracking works at that scale.
|
||||||
|
|
||||||
|
Read more about the new architecture [here](https://docs.litellm.ai/docs/proxy/db_deadlocks)
|
||||||
|
|
||||||
|
|
||||||
|
### New Usage Tab
|
||||||
|
|
||||||
|
<Image img={require('../../img/release_notes/spend_by_model.jpg')} />
|
||||||
|
|
||||||
|
The new Usage tab now brings the ability to track daily spend by model. This makes it easier to catch any spend tracking or token counting errors, when combined with the ability to view successful requests, and token usage.
|
||||||
|
|
||||||
|
To test this out, just go to Experimental > New Usage > Activity.
|
||||||
|
|
||||||
|
|
||||||
|
## New Models / Updated Models
|
||||||
|
|
||||||
|
1. Databricks - claude-3-7-sonnet cost tracking [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L10350)
|
||||||
|
2. VertexAI - `gemini-2.5-pro-exp-03-25` cost tracking [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L4492)
|
||||||
|
3. VertexAI - `gemini-2.0-flash` cost tracking [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L4689)
|
||||||
|
4. Groq - add whisper ASR models to model cost map [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L3324)
|
||||||
|
5. IBM - Add watsonx/ibm/granite-3-8b-instruct to model cost map [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L91)
|
||||||
|
6. Google AI Studio - add gemini/gemini-2.5-pro-preview-03-25 to model cost map [PR](https://github.com/BerriAI/litellm/blob/52b35cd8093b9ad833987b24f494586a1e923209/model_prices_and_context_window.json#L4850)
|
||||||
|
|
||||||
|
## LLM Translation
|
||||||
|
1. Vertex AI - Support anyOf param for OpenAI json schema translation [Get Started](https://docs.litellm.ai/docs/providers/vertex#json-schema)
|
||||||
|
2. Anthropic- response_format + thinking param support (works across Anthropic API, Bedrock, Vertex) [Get Started](https://docs.litellm.ai/docs/reasoning_content)
|
||||||
|
3. Anthropic - if thinking token is specified and max tokens is not - ensure max token to anthropic is higher than thinking tokens (works across Anthropic API, Bedrock, Vertex) [PR](https://github.com/BerriAI/litellm/pull/9594)
|
||||||
|
4. Bedrock - latency optimized inference support [Get Started](https://docs.litellm.ai/docs/providers/bedrock#usage---latency-optimized-inference)
|
||||||
|
5. Sagemaker - handle special tokens + multibyte character code in response [Get Started](https://docs.litellm.ai/docs/providers/aws_sagemaker)
|
||||||
|
6. MCP - add support for using SSE MCP servers [Get Started](https://docs.litellm.ai/docs/mcp#usage)
|
||||||
|
8. Anthropic - new `litellm.messages.create` interface for calling Anthropic `/v1/messages` via passthrough [Get Started](https://docs.litellm.ai/docs/anthropic_unified#usage)
|
||||||
|
11. Anthropic - support ‘file’ content type in message param (works across Anthropic API, Bedrock, Vertex) [Get Started](https://docs.litellm.ai/docs/providers/anthropic#usage---pdf)
|
||||||
|
12. Anthropic - map openai 'reasoning_effort' to anthropic 'thinking' param (works across Anthropic API, Bedrock, Vertex) [Get Started](https://docs.litellm.ai/docs/providers/anthropic#usage---thinking--reasoning_content)
|
||||||
|
13. Google AI Studio (Gemini) - [BETA] `/v1/files` upload support [Get Started](../../docs/providers/google_ai_studio/files)
|
||||||
|
14. Azure - fix o-series tool calling [Get Started](../../docs/providers/azure#tool-calling--function-calling)
|
||||||
|
15. Unified file id - [ALPHA] allow calling multiple providers with same file id [PR](https://github.com/BerriAI/litellm/pull/9718)
|
||||||
|
- This is experimental, and not recommended for production use.
|
||||||
|
- We plan to have a production-ready implementation by next week.
|
||||||
|
16. Google AI Studio (Gemini) - return logprobs [PR](https://github.com/BerriAI/litellm/pull/9713)
|
||||||
|
17. Anthropic - Support prompt caching for Anthropic tool calls [Get Started](https://docs.litellm.ai/docs/completion/prompt_caching)
|
||||||
|
18. OpenRouter - unwrap extra body on open router calls [PR](https://github.com/BerriAI/litellm/pull/9747)
|
||||||
|
19. VertexAI - fix credential caching issue [PR](https://github.com/BerriAI/litellm/pull/9756)
|
||||||
|
20. XAI - filter out 'name' param for XAI [PR](https://github.com/BerriAI/litellm/pull/9761)
|
||||||
|
21. Gemini - image generation output support [Get Started](../../docs/providers/gemini#image-generation)
|
||||||
|
22. Databricks - support claude-3-7-sonnet w/ thinking + response_format [Get Started](../../docs/providers/databricks#usage---thinking--reasoning_content)
|
||||||
|
|
||||||
|
## Spend Tracking Improvements
|
||||||
|
1. Reliability fix - Check sent and received model for cost calculation [PR](https://github.com/BerriAI/litellm/pull/9669)
|
||||||
|
2. Vertex AI - Multimodal embedding cost tracking [Get Started](https://docs.litellm.ai/docs/providers/vertex#multi-modal-embeddings), [PR](https://github.com/BerriAI/litellm/pull/9623)
|
||||||
|
|
||||||
|
## Management Endpoints / UI
|
||||||
|
|
||||||
|
<Image img={require('../../img/release_notes/new_activity_tab.png')} />
|
||||||
|
|
||||||
|
1. New Usage Tab
|
||||||
|
- Report 'total_tokens' + report success/failure calls
|
||||||
|
- Remove double bars on scroll
|
||||||
|
- Ensure ‘daily spend’ chart ordered from earliest to latest date
|
||||||
|
- showing spend per model per day
|
||||||
|
- show key alias on usage tab
|
||||||
|
- Allow non-admins to view their activity
|
||||||
|
- Add date picker to new usage tab
|
||||||
|
2. Virtual Keys Tab
|
||||||
|
- remove 'default key' on user signup
|
||||||
|
- fix showing user models available for personal key creation
|
||||||
|
3. Test Key Tab
|
||||||
|
- Allow testing image generation models
|
||||||
|
4. Models Tab
|
||||||
|
- Fix bulk adding models
|
||||||
|
- support reusable credentials for passthrough endpoints
|
||||||
|
- Allow team members to see team models
|
||||||
|
5. Teams Tab
|
||||||
|
- Fix json serialization error on update team metadata
|
||||||
|
6. Request Logs Tab
|
||||||
|
- Add reasoning_content token tracking across all providers on streaming
|
||||||
|
7. API
|
||||||
|
- return key alias on /user/daily/activity [Get Started](../../docs/proxy/cost_tracking#daily-spend-breakdown-api)
|
||||||
|
8. SSO
|
||||||
|
- Allow assigning SSO users to teams on MSFT SSO [PR](https://github.com/BerriAI/litellm/pull/9745)
|
||||||
|
|
||||||
|
## Logging / Guardrail Integrations
|
||||||
|
|
||||||
|
1. Console Logs - Add json formatting for uncaught exceptions [PR](https://github.com/BerriAI/litellm/pull/9619)
|
||||||
|
2. Guardrails - AIM Guardrails support for virtual key based policies [Get Started](../../docs/proxy/guardrails/aim_security)
|
||||||
|
3. Logging - fix completion start time tracking [PR](https://github.com/BerriAI/litellm/pull/9688)
|
||||||
|
4. Prometheus
|
||||||
|
- Allow adding authentication on Prometheus /metrics endpoints [PR](https://github.com/BerriAI/litellm/pull/9766)
|
||||||
|
- Distinguish LLM Provider Exception vs. LiteLLM Exception in metric naming [PR](https://github.com/BerriAI/litellm/pull/9760)
|
||||||
|
- Emit operational metrics for new DB Transaction architecture [PR](https://github.com/BerriAI/litellm/pull/9719)
|
||||||
|
|
||||||
|
## Performance / Loadbalancing / Reliability improvements
|
||||||
|
1. Preventing Deadlocks
|
||||||
|
- Reduce DB Deadlocks by storing spend updates in Redis and then committing to DB [PR](https://github.com/BerriAI/litellm/pull/9608)
|
||||||
|
- Ensure no deadlocks occur when updating DailyUserSpendTransaction [PR](https://github.com/BerriAI/litellm/pull/9690)
|
||||||
|
- High Traffic fix - ensure new DB + Redis architecture accurately tracks spend [PR](https://github.com/BerriAI/litellm/pull/9673)
|
||||||
|
- Use Redis for PodLock Manager instead of PG (ensures no deadlocks occur) [PR](https://github.com/BerriAI/litellm/pull/9715)
|
||||||
|
- v2 DB Deadlock Reduction Architecture – Add Max Size for In-Memory Queue + Backpressure Mechanism [PR](https://github.com/BerriAI/litellm/pull/9759)
|
||||||
|
|
||||||
|
2. Prisma Migrations [Get Started](../../docs/proxy/prod#9-use-prisma-migrate-deploy)
|
||||||
|
- connects litellm proxy to litellm's prisma migration files
|
||||||
|
- Handle db schema updates from new `litellm-proxy-extras` sdk
|
||||||
|
3. Redis - support password for sync sentinel clients [PR](https://github.com/BerriAI/litellm/pull/9622)
|
||||||
|
4. Fix "Circular reference detected" error when max_parallel_requests = 0 [PR](https://github.com/BerriAI/litellm/pull/9671)
|
||||||
|
5. Code QA - Ban hardcoded numbers [PR](https://github.com/BerriAI/litellm/pull/9709)
|
||||||
|
|
||||||
|
## Helm
|
||||||
|
1. fix: wrong indentation of ttlSecondsAfterFinished in chart [PR](https://github.com/BerriAI/litellm/pull/9611)
|
||||||
|
|
||||||
|
## General Proxy Improvements
|
||||||
|
1. Fix - only apply service_account_settings.enforced_params on service accounts [PR](https://github.com/BerriAI/litellm/pull/9683)
|
||||||
|
2. Fix - handle metadata null on `/chat/completion` [PR](https://github.com/BerriAI/litellm/issues/9717)
|
||||||
|
3. Fix - Move daily user transaction logging outside of 'disable_spend_logs' flag, as they’re unrelated [PR](https://github.com/BerriAI/litellm/pull/9772)
|
||||||
|
|
||||||
|
## Demo
|
||||||
|
|
||||||
|
Try this on the demo instance [today](https://docs.litellm.ai/docs/proxy/demo)
|
||||||
|
|
||||||
|
## Complete Git Diff
|
||||||
|
|
||||||
|
See the complete git diff since v1.65.0-stable, [here](https://github.com/BerriAI/litellm/releases/tag/v1.65.4-stable)
|
||||||
|
|
|
@ -53,7 +53,7 @@ const sidebars = {
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
label: "Architecture",
|
label: "Architecture",
|
||||||
items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch", "proxy/image_handling"],
|
items: ["proxy/architecture", "proxy/db_info", "proxy/db_deadlocks", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch", "proxy/image_handling"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "link",
|
type: "link",
|
||||||
|
@ -188,7 +188,15 @@ const sidebars = {
|
||||||
"providers/azure_ai",
|
"providers/azure_ai",
|
||||||
"providers/aiml",
|
"providers/aiml",
|
||||||
"providers/vertex",
|
"providers/vertex",
|
||||||
"providers/gemini",
|
|
||||||
|
{
|
||||||
|
type: "category",
|
||||||
|
label: "Google AI Studio",
|
||||||
|
items: [
|
||||||
|
"providers/gemini",
|
||||||
|
"providers/google_ai_studio/files",
|
||||||
|
]
|
||||||
|
},
|
||||||
"providers/anthropic",
|
"providers/anthropic",
|
||||||
"providers/aws_sagemaker",
|
"providers/aws_sagemaker",
|
||||||
"providers/bedrock",
|
"providers/bedrock",
|
||||||
|
|
BIN
litellm-proxy-extras/dist/litellm_proxy_extras-0.1.3-py3-none-any.whl
vendored
Normal file
BIN
litellm-proxy-extras/dist/litellm_proxy_extras-0.1.3-py3-none-any.whl
vendored
Normal file
Binary file not shown.
BIN
litellm-proxy-extras/dist/litellm_proxy_extras-0.1.3.tar.gz
vendored
Normal file
BIN
litellm-proxy-extras/dist/litellm_proxy_extras-0.1.3.tar.gz
vendored
Normal file
Binary file not shown.
356
litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Normal file
356
litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Normal file
|
@ -0,0 +1,356 @@
|
||||||
|
datasource client {
|
||||||
|
provider = "postgresql"
|
||||||
|
url = env("DATABASE_URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
generator client {
|
||||||
|
provider = "prisma-client-py"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget / Rate Limits for an org
|
||||||
|
model LiteLLM_BudgetTable {
|
||||||
|
budget_id String @id @default(uuid())
|
||||||
|
max_budget Float?
|
||||||
|
soft_budget Float?
|
||||||
|
max_parallel_requests Int?
|
||||||
|
tpm_limit BigInt?
|
||||||
|
rpm_limit BigInt?
|
||||||
|
model_max_budget Json?
|
||||||
|
budget_duration String?
|
||||||
|
budget_reset_at DateTime?
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
|
||||||
|
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
|
||||||
|
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
|
||||||
|
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
|
||||||
|
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
|
||||||
|
}
|
||||||
|
|
||||||
|
// Models on proxy
|
||||||
|
model LiteLLM_CredentialsTable {
|
||||||
|
credential_id String @id @default(uuid())
|
||||||
|
credential_name String @unique
|
||||||
|
credential_values Json
|
||||||
|
credential_info Json?
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
}
|
||||||
|
|
||||||
|
// Models on proxy
|
||||||
|
model LiteLLM_ProxyModelTable {
|
||||||
|
model_id String @id @default(uuid())
|
||||||
|
model_name String
|
||||||
|
litellm_params Json
|
||||||
|
model_info Json?
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_OrganizationTable {
|
||||||
|
organization_id String @id @default(uuid())
|
||||||
|
organization_alias String
|
||||||
|
budget_id String
|
||||||
|
metadata Json @default("{}")
|
||||||
|
models String[]
|
||||||
|
spend Float @default(0.0)
|
||||||
|
model_spend Json @default("{}")
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||||
|
teams LiteLLM_TeamTable[]
|
||||||
|
users LiteLLM_UserTable[]
|
||||||
|
keys LiteLLM_VerificationToken[]
|
||||||
|
members LiteLLM_OrganizationMembership[] @relation("OrganizationToMembership")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model info for teams, just has model aliases for now.
|
||||||
|
model LiteLLM_ModelTable {
|
||||||
|
id Int @id @default(autoincrement())
|
||||||
|
model_aliases Json? @map("aliases")
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
team LiteLLM_TeamTable?
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Assign prod keys to groups, not individuals
|
||||||
|
model LiteLLM_TeamTable {
|
||||||
|
team_id String @id @default(uuid())
|
||||||
|
team_alias String?
|
||||||
|
organization_id String?
|
||||||
|
admins String[]
|
||||||
|
members String[]
|
||||||
|
members_with_roles Json @default("{}")
|
||||||
|
metadata Json @default("{}")
|
||||||
|
max_budget Float?
|
||||||
|
spend Float @default(0.0)
|
||||||
|
models String[]
|
||||||
|
max_parallel_requests Int?
|
||||||
|
tpm_limit BigInt?
|
||||||
|
rpm_limit BigInt?
|
||||||
|
budget_duration String?
|
||||||
|
budget_reset_at DateTime?
|
||||||
|
blocked Boolean @default(false)
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
model_spend Json @default("{}")
|
||||||
|
model_max_budget Json @default("{}")
|
||||||
|
model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases
|
||||||
|
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
||||||
|
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track spend, rate limit, budget Users
|
||||||
|
model LiteLLM_UserTable {
|
||||||
|
user_id String @id
|
||||||
|
user_alias String?
|
||||||
|
team_id String?
|
||||||
|
sso_user_id String? @unique
|
||||||
|
organization_id String?
|
||||||
|
password String?
|
||||||
|
teams String[] @default([])
|
||||||
|
user_role String?
|
||||||
|
max_budget Float?
|
||||||
|
spend Float @default(0.0)
|
||||||
|
user_email String?
|
||||||
|
models String[]
|
||||||
|
metadata Json @default("{}")
|
||||||
|
max_parallel_requests Int?
|
||||||
|
tpm_limit BigInt?
|
||||||
|
rpm_limit BigInt?
|
||||||
|
budget_duration String?
|
||||||
|
budget_reset_at DateTime?
|
||||||
|
allowed_cache_controls String[] @default([])
|
||||||
|
model_spend Json @default("{}")
|
||||||
|
model_max_budget Json @default("{}")
|
||||||
|
created_at DateTime? @default(now()) @map("created_at")
|
||||||
|
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
|
||||||
|
|
||||||
|
// relations
|
||||||
|
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
||||||
|
organization_memberships LiteLLM_OrganizationMembership[]
|
||||||
|
invitations_created LiteLLM_InvitationLink[] @relation("CreatedBy")
|
||||||
|
invitations_updated LiteLLM_InvitationLink[] @relation("UpdatedBy")
|
||||||
|
invitations_user LiteLLM_InvitationLink[] @relation("UserId")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate Tokens for Proxy
|
||||||
|
model LiteLLM_VerificationToken {
|
||||||
|
token String @id
|
||||||
|
key_name String?
|
||||||
|
key_alias String?
|
||||||
|
soft_budget_cooldown Boolean @default(false) // key-level state on if budget alerts need to be cooled down
|
||||||
|
spend Float @default(0.0)
|
||||||
|
expires DateTime?
|
||||||
|
models String[]
|
||||||
|
aliases Json @default("{}")
|
||||||
|
config Json @default("{}")
|
||||||
|
user_id String?
|
||||||
|
team_id String?
|
||||||
|
permissions Json @default("{}")
|
||||||
|
max_parallel_requests Int?
|
||||||
|
metadata Json @default("{}")
|
||||||
|
blocked Boolean?
|
||||||
|
tpm_limit BigInt?
|
||||||
|
rpm_limit BigInt?
|
||||||
|
max_budget Float?
|
||||||
|
budget_duration String?
|
||||||
|
budget_reset_at DateTime?
|
||||||
|
allowed_cache_controls String[] @default([])
|
||||||
|
model_spend Json @default("{}")
|
||||||
|
model_max_budget Json @default("{}")
|
||||||
|
budget_id String?
|
||||||
|
organization_id String?
|
||||||
|
created_at DateTime? @default(now()) @map("created_at")
|
||||||
|
created_by String?
|
||||||
|
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String?
|
||||||
|
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||||
|
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_EndUserTable {
|
||||||
|
user_id String @id
|
||||||
|
alias String? // admin-facing alias
|
||||||
|
spend Float @default(0.0)
|
||||||
|
allowed_model_region String? // require all user requests to use models in this specific region
|
||||||
|
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
|
||||||
|
budget_id String?
|
||||||
|
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||||
|
blocked Boolean @default(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// store proxy config.yaml
|
||||||
|
model LiteLLM_Config {
|
||||||
|
param_name String @id
|
||||||
|
param_value Json?
|
||||||
|
}
|
||||||
|
|
||||||
|
// View spend, model, api_key per request
|
||||||
|
model LiteLLM_SpendLogs {
|
||||||
|
request_id String @id
|
||||||
|
call_type String
|
||||||
|
api_key String @default ("") // Hashed API Token. Not the actual Virtual Key. Equivalent to 'token' column in LiteLLM_VerificationToken
|
||||||
|
spend Float @default(0.0)
|
||||||
|
total_tokens Int @default(0)
|
||||||
|
prompt_tokens Int @default(0)
|
||||||
|
completion_tokens Int @default(0)
|
||||||
|
startTime DateTime // Assuming start_time is a DateTime field
|
||||||
|
endTime DateTime // Assuming end_time is a DateTime field
|
||||||
|
completionStartTime DateTime? // Assuming completionStartTime is a DateTime field
|
||||||
|
model String @default("")
|
||||||
|
model_id String? @default("") // the model id stored in proxy model db
|
||||||
|
model_group String? @default("") // public model_name / model_group
|
||||||
|
custom_llm_provider String? @default("") // litellm used custom_llm_provider
|
||||||
|
api_base String? @default("")
|
||||||
|
user String? @default("")
|
||||||
|
metadata Json? @default("{}")
|
||||||
|
cache_hit String? @default("")
|
||||||
|
cache_key String? @default("")
|
||||||
|
request_tags Json? @default("[]")
|
||||||
|
team_id String?
|
||||||
|
end_user String?
|
||||||
|
requester_ip_address String?
|
||||||
|
messages Json? @default("{}")
|
||||||
|
response Json? @default("{}")
|
||||||
|
@@index([startTime])
|
||||||
|
@@index([end_user])
|
||||||
|
}
|
||||||
|
|
||||||
|
// View spend, model, api_key per request
|
||||||
|
model LiteLLM_ErrorLogs {
|
||||||
|
request_id String @id @default(uuid())
|
||||||
|
startTime DateTime // Assuming start_time is a DateTime field
|
||||||
|
endTime DateTime // Assuming end_time is a DateTime field
|
||||||
|
api_base String @default("")
|
||||||
|
model_group String @default("") // public model_name / model_group
|
||||||
|
litellm_model_name String @default("") // model passed to litellm
|
||||||
|
model_id String @default("") // ID of model in ProxyModelTable
|
||||||
|
request_kwargs Json @default("{}")
|
||||||
|
exception_type String @default("")
|
||||||
|
exception_string String @default("")
|
||||||
|
status_code String @default("")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Beta - allow team members to request access to a model
|
||||||
|
model LiteLLM_UserNotifications {
|
||||||
|
request_id String @id
|
||||||
|
user_id String
|
||||||
|
models String[]
|
||||||
|
justification String
|
||||||
|
status String // approved, disapproved, pending
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_TeamMembership {
|
||||||
|
// Use this table to track the Internal User's Spend within a Team + Set Budgets, rpm limits for the user within the team
|
||||||
|
user_id String
|
||||||
|
team_id String
|
||||||
|
spend Float @default(0.0)
|
||||||
|
budget_id String?
|
||||||
|
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||||
|
@@id([user_id, team_id])
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_OrganizationMembership {
|
||||||
|
// Use this table to track Internal User and Organization membership. Helps tracking a users role within an Organization
|
||||||
|
user_id String
|
||||||
|
organization_id String
|
||||||
|
user_role String?
|
||||||
|
spend Float? @default(0.0)
|
||||||
|
budget_id String?
|
||||||
|
created_at DateTime? @default(now()) @map("created_at")
|
||||||
|
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
|
||||||
|
|
||||||
|
// relations
|
||||||
|
user LiteLLM_UserTable @relation(fields: [user_id], references: [user_id])
|
||||||
|
organization LiteLLM_OrganizationTable @relation("OrganizationToMembership", fields: [organization_id], references: [organization_id])
|
||||||
|
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@@id([user_id, organization_id])
|
||||||
|
@@unique([user_id, organization_id])
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_InvitationLink {
|
||||||
|
// use this table to track invite links sent by admin for people to join the proxy
|
||||||
|
id String @id @default(uuid())
|
||||||
|
user_id String
|
||||||
|
is_accepted Boolean @default(false)
|
||||||
|
accepted_at DateTime? // when link is claimed (user successfully onboards via link)
|
||||||
|
expires_at DateTime // till when is link valid
|
||||||
|
created_at DateTime // when did admin create the link
|
||||||
|
created_by String // who created the link
|
||||||
|
updated_at DateTime // when was invite status updated
|
||||||
|
updated_by String // who updated the status (admin/user who accepted invite)
|
||||||
|
|
||||||
|
// Relations
|
||||||
|
liteLLM_user_table_user LiteLLM_UserTable @relation("UserId", fields: [user_id], references: [user_id])
|
||||||
|
liteLLM_user_table_created LiteLLM_UserTable @relation("CreatedBy", fields: [created_by], references: [user_id])
|
||||||
|
liteLLM_user_table_updated LiteLLM_UserTable @relation("UpdatedBy", fields: [updated_by], references: [user_id])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
model LiteLLM_AuditLog {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
updated_at DateTime @default(now())
|
||||||
|
changed_by String @default("") // user or system that performed the action
|
||||||
|
changed_by_api_key String @default("") // api key hash that performed the action
|
||||||
|
action String // create, update, delete
|
||||||
|
table_name String // on of LitellmTableNames.TEAM_TABLE_NAME, LitellmTableNames.USER_TABLE_NAME, LitellmTableNames.PROXY_MODEL_TABLE_NAME,
|
||||||
|
object_id String // id of the object being audited. This can be the key id, team id, user id, model id
|
||||||
|
before_value Json? // value of the row
|
||||||
|
updated_values Json? // value of the row after change
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track daily user spend metrics per model and key
|
||||||
|
model LiteLLM_DailyUserSpend {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
user_id String
|
||||||
|
date String
|
||||||
|
api_key String
|
||||||
|
model String
|
||||||
|
model_group String?
|
||||||
|
custom_llm_provider String?
|
||||||
|
prompt_tokens Int @default(0)
|
||||||
|
completion_tokens Int @default(0)
|
||||||
|
spend Float @default(0.0)
|
||||||
|
api_requests Int @default(0)
|
||||||
|
successful_requests Int @default(0)
|
||||||
|
failed_requests Int @default(0)
|
||||||
|
created_at DateTime @default(now())
|
||||||
|
updated_at DateTime @updatedAt
|
||||||
|
|
||||||
|
@@unique([user_id, date, api_key, model, custom_llm_provider])
|
||||||
|
@@index([date])
|
||||||
|
@@index([user_id])
|
||||||
|
@@index([api_key])
|
||||||
|
@@index([model])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Track the status of cron jobs running. Only allow one pod to run the job at a time
|
||||||
|
model LiteLLM_CronJob {
|
||||||
|
cronjob_id String @id @default(cuid()) // Unique ID for the record
|
||||||
|
pod_id String // Unique identifier for the pod acting as the leader
|
||||||
|
status JobStatus @default(INACTIVE) // Status of the cron job (active or inactive)
|
||||||
|
last_updated DateTime @default(now()) // Timestamp for the last update of the cron job record
|
||||||
|
ttl DateTime // Time when the leader's lease expires
|
||||||
|
}
|
||||||
|
|
||||||
|
enum JobStatus {
|
||||||
|
ACTIVE
|
||||||
|
INACTIVE
|
||||||
|
}
|
||||||
|
|
|
@ -30,21 +30,23 @@ class ProxyExtrasDBManager:
|
||||||
use_migrate = str_to_bool(os.getenv("USE_PRISMA_MIGRATE")) or use_migrate
|
use_migrate = str_to_bool(os.getenv("USE_PRISMA_MIGRATE")) or use_migrate
|
||||||
for attempt in range(4):
|
for attempt in range(4):
|
||||||
original_dir = os.getcwd()
|
original_dir = os.getcwd()
|
||||||
schema_dir = os.path.dirname(schema_path)
|
migrations_dir = os.path.dirname(__file__)
|
||||||
os.chdir(schema_dir)
|
os.chdir(migrations_dir)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if use_migrate:
|
if use_migrate:
|
||||||
logger.info("Running prisma migrate deploy")
|
logger.info("Running prisma migrate deploy")
|
||||||
try:
|
try:
|
||||||
# Set migrations directory for Prisma
|
# Set migrations directory for Prisma
|
||||||
subprocess.run(
|
result = subprocess.run(
|
||||||
["prisma", "migrate", "deploy"],
|
["prisma", "migrate", "deploy"],
|
||||||
timeout=60,
|
timeout=60,
|
||||||
check=True,
|
check=True,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
)
|
)
|
||||||
|
logger.info(f"prisma migrate deploy stdout: {result.stdout}")
|
||||||
|
|
||||||
logger.info("prisma migrate deploy completed")
|
logger.info("prisma migrate deploy completed")
|
||||||
return True
|
return True
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
|
@ -77,4 +79,5 @@ class ProxyExtrasDBManager:
|
||||||
time.sleep(random.randrange(5, 15))
|
time.sleep(random.randrange(5, 15))
|
||||||
finally:
|
finally:
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
|
pass
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm-proxy-extras"
|
name = "litellm-proxy-extras"
|
||||||
version = "0.1.2"
|
version = "0.1.3"
|
||||||
description = "Additional files for the LiteLLM Proxy. Reduces the size of the main litellm package."
|
description = "Additional files for the LiteLLM Proxy. Reduces the size of the main litellm package."
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -22,7 +22,7 @@ requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "0.1.2"
|
version = "0.1.3"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:version",
|
"pyproject.toml:version",
|
||||||
"../requirements.txt:litellm-proxy-extras==",
|
"../requirements.txt:litellm-proxy-extras==",
|
||||||
|
|
|
@ -56,6 +56,9 @@ from litellm.constants import (
|
||||||
bedrock_embedding_models,
|
bedrock_embedding_models,
|
||||||
known_tokenizer_config,
|
known_tokenizer_config,
|
||||||
BEDROCK_INVOKE_PROVIDERS_LITERAL,
|
BEDROCK_INVOKE_PROVIDERS_LITERAL,
|
||||||
|
DEFAULT_MAX_TOKENS,
|
||||||
|
DEFAULT_SOFT_BUDGET,
|
||||||
|
DEFAULT_ALLOWED_FAILS,
|
||||||
)
|
)
|
||||||
from litellm.types.guardrails import GuardrailItem
|
from litellm.types.guardrails import GuardrailItem
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
|
@ -120,6 +123,7 @@ callbacks: List[
|
||||||
langfuse_default_tags: Optional[List[str]] = None
|
langfuse_default_tags: Optional[List[str]] = None
|
||||||
langsmith_batch_size: Optional[int] = None
|
langsmith_batch_size: Optional[int] = None
|
||||||
prometheus_initialize_budget_metrics: Optional[bool] = False
|
prometheus_initialize_budget_metrics: Optional[bool] = False
|
||||||
|
require_auth_for_metrics_endpoint: Optional[bool] = False
|
||||||
argilla_batch_size: Optional[int] = None
|
argilla_batch_size: Optional[int] = None
|
||||||
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
||||||
gcs_pub_sub_use_v1: Optional[
|
gcs_pub_sub_use_v1: Optional[
|
||||||
|
@ -155,7 +159,7 @@ token: Optional[
|
||||||
str
|
str
|
||||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
telemetry = True
|
telemetry = True
|
||||||
max_tokens = 256 # OpenAI Defaults
|
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
|
||||||
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
||||||
modify_params = False
|
modify_params = False
|
||||||
retry = True
|
retry = True
|
||||||
|
@ -244,7 +248,7 @@ budget_duration: Optional[
|
||||||
str
|
str
|
||||||
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||||
default_soft_budget: float = (
|
default_soft_budget: float = (
|
||||||
50.0 # by default all litellm proxy keys have a soft budget of 50.0
|
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
|
||||||
)
|
)
|
||||||
forward_traceparent_to_llm_provider: bool = False
|
forward_traceparent_to_llm_provider: bool = False
|
||||||
|
|
||||||
|
@ -796,9 +800,8 @@ from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
|
||||||
from .llms.galadriel.chat.transformation import GaladrielChatConfig
|
from .llms.galadriel.chat.transformation import GaladrielChatConfig
|
||||||
from .llms.github.chat.transformation import GithubChatConfig
|
from .llms.github.chat.transformation import GithubChatConfig
|
||||||
from .llms.empower.chat.transformation import EmpowerChatConfig
|
from .llms.empower.chat.transformation import EmpowerChatConfig
|
||||||
from .llms.huggingface.chat.transformation import (
|
from .llms.huggingface.chat.transformation import HuggingFaceChatConfig
|
||||||
HuggingfaceChatConfig as HuggingfaceConfig,
|
from .llms.huggingface.embedding.transformation import HuggingFaceEmbeddingConfig
|
||||||
)
|
|
||||||
from .llms.oobabooga.chat.transformation import OobaboogaConfig
|
from .llms.oobabooga.chat.transformation import OobaboogaConfig
|
||||||
from .llms.maritalk import MaritalkConfig
|
from .llms.maritalk import MaritalkConfig
|
||||||
from .llms.openrouter.chat.transformation import OpenrouterConfig
|
from .llms.openrouter.chat.transformation import OpenrouterConfig
|
||||||
|
|
|
@ -18,6 +18,7 @@ import redis # type: ignore
|
||||||
import redis.asyncio as async_redis # type: ignore
|
import redis.asyncio as async_redis # type: ignore
|
||||||
|
|
||||||
from litellm import get_secret, get_secret_str
|
from litellm import get_secret, get_secret_str
|
||||||
|
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
|
||||||
|
|
||||||
from ._logging import verbose_logger
|
from ._logging import verbose_logger
|
||||||
|
|
||||||
|
@ -215,7 +216,7 @@ def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
|
||||||
# Set up the Sentinel client
|
# Set up the Sentinel client
|
||||||
sentinel = redis.Sentinel(
|
sentinel = redis.Sentinel(
|
||||||
sentinel_nodes,
|
sentinel_nodes,
|
||||||
socket_timeout=0.1,
|
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
||||||
password=sentinel_password,
|
password=sentinel_password,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -239,7 +240,7 @@ def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
|
||||||
# Set up the Sentinel client
|
# Set up the Sentinel client
|
||||||
sentinel = async_redis.Sentinel(
|
sentinel = async_redis.Sentinel(
|
||||||
sentinel_nodes,
|
sentinel_nodes,
|
||||||
socket_timeout=0.1,
|
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
||||||
password=sentinel_password,
|
password=sentinel_password,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -319,7 +320,7 @@ def get_redis_connection_pool(**env_overrides):
|
||||||
verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs)
|
verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs)
|
||||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||||
return async_redis.BlockingConnectionPool.from_url(
|
return async_redis.BlockingConnectionPool.from_url(
|
||||||
timeout=5, url=redis_kwargs["url"]
|
timeout=REDIS_CONNECTION_POOL_TIMEOUT, url=redis_kwargs["url"]
|
||||||
)
|
)
|
||||||
connection_class = async_redis.Connection
|
connection_class = async_redis.Connection
|
||||||
if "ssl" in redis_kwargs:
|
if "ssl" in redis_kwargs:
|
||||||
|
@ -327,4 +328,6 @@ def get_redis_connection_pool(**env_overrides):
|
||||||
redis_kwargs.pop("ssl", None)
|
redis_kwargs.pop("ssl", None)
|
||||||
redis_kwargs["connection_class"] = connection_class
|
redis_kwargs["connection_class"] = connection_class
|
||||||
redis_kwargs.pop("startup_nodes", None)
|
redis_kwargs.pop("startup_nodes", None)
|
||||||
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
|
return async_redis.BlockingConnectionPool(
|
||||||
|
timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
|
||||||
|
)
|
||||||
|
|
|
@ -124,6 +124,7 @@ class ServiceLogging(CustomLogger):
|
||||||
service=service,
|
service=service,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
call_type=call_type,
|
call_type=call_type,
|
||||||
|
event_metadata=event_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
for callback in litellm.service_callback:
|
for callback in litellm.service_callback:
|
||||||
|
@ -229,6 +230,7 @@ class ServiceLogging(CustomLogger):
|
||||||
service=service,
|
service=service,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
call_type=call_type,
|
call_type=call_type,
|
||||||
|
event_metadata=event_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
for callback in litellm.service_callback:
|
for callback in litellm.service_callback:
|
||||||
|
|
|
@ -14,6 +14,12 @@ import time
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.constants import (
|
||||||
|
DAYS_IN_A_MONTH,
|
||||||
|
DAYS_IN_A_WEEK,
|
||||||
|
DAYS_IN_A_YEAR,
|
||||||
|
HOURS_IN_A_DAY,
|
||||||
|
)
|
||||||
from litellm.utils import ModelResponse
|
from litellm.utils import ModelResponse
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,11 +87,11 @@ class BudgetManager:
|
||||||
if duration == "daily":
|
if duration == "daily":
|
||||||
duration_in_days = 1
|
duration_in_days = 1
|
||||||
elif duration == "weekly":
|
elif duration == "weekly":
|
||||||
duration_in_days = 7
|
duration_in_days = DAYS_IN_A_WEEK
|
||||||
elif duration == "monthly":
|
elif duration == "monthly":
|
||||||
duration_in_days = 28
|
duration_in_days = DAYS_IN_A_MONTH
|
||||||
elif duration == "yearly":
|
elif duration == "yearly":
|
||||||
duration_in_days = 365
|
duration_in_days = DAYS_IN_A_YEAR
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
|
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
|
||||||
|
@ -182,7 +188,9 @@ class BudgetManager:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# Convert duration from days to seconds
|
# Convert duration from days to seconds
|
||||||
duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60
|
duration_in_seconds = (
|
||||||
|
self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
|
||||||
|
)
|
||||||
|
|
||||||
# Check if duration has elapsed
|
# Check if duration has elapsed
|
||||||
if current_time - last_updated_at >= duration_in_seconds:
|
if current_time - last_updated_at >= duration_in_seconds:
|
||||||
|
|
|
@ -19,6 +19,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
|
||||||
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
||||||
from litellm.types.caching import *
|
from litellm.types.caching import *
|
||||||
from litellm.types.utils import all_litellm_params
|
from litellm.types.utils import all_litellm_params
|
||||||
|
@ -406,7 +407,7 @@ class Cache:
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
time.sleep(0.02)
|
time.sleep(CACHED_STREAMING_CHUNK_DELAY)
|
||||||
|
|
||||||
def _get_cache_logic(
|
def _get_cache_logic(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -15,7 +15,8 @@ from typing import Any, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ..constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||||
|
|
||||||
from .base_cache import BaseCache
|
from .base_cache import BaseCache
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,7 +53,8 @@ class InMemoryCache(BaseCache):
|
||||||
# Fast path for common primitive types that are typically small
|
# Fast path for common primitive types that are typically small
|
||||||
if (
|
if (
|
||||||
isinstance(value, (bool, int, float, str))
|
isinstance(value, (bool, int, float, str))
|
||||||
and len(str(value)) < self.max_size_per_item * 512
|
and len(str(value))
|
||||||
|
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||||
): # Conservative estimate
|
): # Conservative estimate
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,12 @@ Has 4 methods:
|
||||||
import ast
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import print_verbose
|
from litellm._logging import print_verbose
|
||||||
|
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
|
||||||
|
from litellm.types.utils import EmbeddingResponse
|
||||||
|
|
||||||
from .base_cache import BaseCache
|
from .base_cache import BaseCache
|
||||||
|
|
||||||
|
@ -118,7 +120,11 @@ class QdrantSemanticCache(BaseCache):
|
||||||
}
|
}
|
||||||
elif quantization_config == "scalar":
|
elif quantization_config == "scalar":
|
||||||
quantization_params = {
|
quantization_params = {
|
||||||
"scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}
|
"scalar": {
|
||||||
|
"type": "int8",
|
||||||
|
"quantile": QDRANT_SCALAR_QUANTILE,
|
||||||
|
"always_ram": False,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
elif quantization_config == "product":
|
elif quantization_config == "product":
|
||||||
quantization_params = {
|
quantization_params = {
|
||||||
|
@ -132,7 +138,7 @@ class QdrantSemanticCache(BaseCache):
|
||||||
new_collection_status = self.sync_client.put(
|
new_collection_status = self.sync_client.put(
|
||||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||||
json={
|
json={
|
||||||
"vectors": {"size": 1536, "distance": "Cosine"},
|
"vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
|
||||||
"quantization_config": quantization_params,
|
"quantization_config": quantization_params,
|
||||||
},
|
},
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
|
@ -171,10 +177,13 @@ class QdrantSemanticCache(BaseCache):
|
||||||
prompt += message["content"]
|
prompt += message["content"]
|
||||||
|
|
||||||
# create an embedding for prompt
|
# create an embedding for prompt
|
||||||
embedding_response = litellm.embedding(
|
embedding_response = cast(
|
||||||
model=self.embedding_model,
|
EmbeddingResponse,
|
||||||
input=prompt,
|
litellm.embedding(
|
||||||
cache={"no-store": True, "no-cache": True},
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# get the embedding
|
# get the embedding
|
||||||
|
@ -212,10 +221,13 @@ class QdrantSemanticCache(BaseCache):
|
||||||
prompt += message["content"]
|
prompt += message["content"]
|
||||||
|
|
||||||
# convert to embedding
|
# convert to embedding
|
||||||
embedding_response = litellm.embedding(
|
embedding_response = cast(
|
||||||
model=self.embedding_model,
|
EmbeddingResponse,
|
||||||
input=prompt,
|
litellm.embedding(
|
||||||
cache={"no-store": True, "no-cache": True},
|
model=self.embedding_model,
|
||||||
|
input=prompt,
|
||||||
|
cache={"no-store": True, "no-cache": True},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# get the embedding
|
# get the embedding
|
||||||
|
|
|
@ -304,12 +304,18 @@ class RedisCache(BaseCache):
|
||||||
|
|
||||||
key = self.check_and_fix_namespace(key=key)
|
key = self.check_and_fix_namespace(key=key)
|
||||||
ttl = self.get_ttl(**kwargs)
|
ttl = self.get_ttl(**kwargs)
|
||||||
|
nx = kwargs.get("nx", False)
|
||||||
print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
|
print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not hasattr(_redis_client, "set"):
|
if not hasattr(_redis_client, "set"):
|
||||||
raise Exception("Redis client cannot set cache. Attribute not found.")
|
raise Exception("Redis client cannot set cache. Attribute not found.")
|
||||||
await _redis_client.set(name=key, value=json.dumps(value), ex=ttl)
|
result = await _redis_client.set(
|
||||||
|
name=key,
|
||||||
|
value=json.dumps(value),
|
||||||
|
nx=nx,
|
||||||
|
ex=ttl,
|
||||||
|
)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||||
)
|
)
|
||||||
|
@ -326,6 +332,7 @@ class RedisCache(BaseCache):
|
||||||
event_metadata={"key": key},
|
event_metadata={"key": key},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
_duration = end_time - start_time
|
_duration = end_time - start_time
|
||||||
|
@ -931,7 +938,7 @@ class RedisCache(BaseCache):
|
||||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
||||||
_redis_client: Any = self.init_async_client()
|
_redis_client: Any = self.init_async_client()
|
||||||
# keys is str
|
# keys is str
|
||||||
await _redis_client.delete(key)
|
return await _redis_client.delete(key)
|
||||||
|
|
||||||
def delete_cache(self, key):
|
def delete_cache(self, key):
|
||||||
self.redis_client.delete(key)
|
self.redis_client.delete(key)
|
||||||
|
|
|
@ -9,6 +9,7 @@ DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
||||||
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
||||||
)
|
)
|
||||||
DEFAULT_MAX_TOKENS = 4096
|
DEFAULT_MAX_TOKENS = 4096
|
||||||
|
DEFAULT_ALLOWED_FAILS = 3
|
||||||
DEFAULT_REDIS_SYNC_INTERVAL = 1
|
DEFAULT_REDIS_SYNC_INTERVAL = 1
|
||||||
DEFAULT_COOLDOWN_TIME_SECONDS = 5
|
DEFAULT_COOLDOWN_TIME_SECONDS = 5
|
||||||
DEFAULT_REPLICATE_POLLING_RETRIES = 5
|
DEFAULT_REPLICATE_POLLING_RETRIES = 5
|
||||||
|
@ -16,16 +17,76 @@ DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
|
||||||
DEFAULT_IMAGE_TOKEN_COUNT = 250
|
DEFAULT_IMAGE_TOKEN_COUNT = 250
|
||||||
DEFAULT_IMAGE_WIDTH = 300
|
DEFAULT_IMAGE_WIDTH = 300
|
||||||
DEFAULT_IMAGE_HEIGHT = 300
|
DEFAULT_IMAGE_HEIGHT = 300
|
||||||
|
DEFAULT_MAX_TOKENS = 256 # used when providers need a default
|
||||||
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
|
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
|
||||||
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
|
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
|
||||||
|
|
||||||
|
########### v2 Architecture constants for managing writing updates to the database ###########
|
||||||
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
|
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
|
||||||
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
|
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
|
||||||
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
|
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
|
||||||
|
MAX_SIZE_IN_MEMORY_QUEUE = 10000
|
||||||
|
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000
|
||||||
|
###############################################################################################
|
||||||
|
MINIMUM_PROMPT_CACHE_TOKEN_COUNT = (
|
||||||
|
1024 # minimum number of tokens to cache a prompt by Anthropic
|
||||||
|
)
|
||||||
|
DEFAULT_TRIM_RATIO = 0.75 # default ratio of tokens to trim from the end of a prompt
|
||||||
|
HOURS_IN_A_DAY = 24
|
||||||
|
DAYS_IN_A_WEEK = 7
|
||||||
|
DAYS_IN_A_MONTH = 28
|
||||||
|
DAYS_IN_A_YEAR = 365
|
||||||
|
REPLICATE_MODEL_NAME_WITH_ID_LENGTH = 64
|
||||||
|
#### TOKEN COUNTING ####
|
||||||
|
FUNCTION_DEFINITION_TOKEN_COUNT = 9
|
||||||
|
SYSTEM_MESSAGE_TOKEN_COUNT = 4
|
||||||
|
TOOL_CHOICE_OBJECT_TOKEN_COUNT = 4
|
||||||
|
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT = 10
|
||||||
|
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT = 20
|
||||||
|
MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES = 768
|
||||||
|
MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES = 2000
|
||||||
|
MAX_TILE_WIDTH = 512
|
||||||
|
MAX_TILE_HEIGHT = 512
|
||||||
|
OPENAI_FILE_SEARCH_COST_PER_1K_CALLS = 2.5 / 1000
|
||||||
|
MIN_NON_ZERO_TEMPERATURE = 0.0001
|
||||||
#### RELIABILITY ####
|
#### RELIABILITY ####
|
||||||
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
|
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
|
||||||
|
DEFAULT_MAX_LRU_CACHE_SIZE = 16
|
||||||
|
INITIAL_RETRY_DELAY = 0.5
|
||||||
|
MAX_RETRY_DELAY = 8.0
|
||||||
|
JITTER = 0.75
|
||||||
|
DEFAULT_IN_MEMORY_TTL = 5 # default time to live for the in-memory cache
|
||||||
|
DEFAULT_POLLING_INTERVAL = 0.03 # default polling interval for the scheduler
|
||||||
|
AZURE_OPERATION_POLLING_TIMEOUT = 120
|
||||||
|
REDIS_SOCKET_TIMEOUT = 0.1
|
||||||
|
REDIS_CONNECTION_POOL_TIMEOUT = 5
|
||||||
|
NON_LLM_CONNECTION_TIMEOUT = 15 # timeout for adjacent services (e.g. jwt auth)
|
||||||
|
MAX_EXCEPTION_MESSAGE_LENGTH = 2000
|
||||||
|
BEDROCK_MAX_POLICY_SIZE = 75
|
||||||
|
REPLICATE_POLLING_DELAY_SECONDS = 0.5
|
||||||
|
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS = 4096
|
||||||
|
TOGETHER_AI_4_B = 4
|
||||||
|
TOGETHER_AI_8_B = 8
|
||||||
|
TOGETHER_AI_21_B = 21
|
||||||
|
TOGETHER_AI_41_B = 41
|
||||||
|
TOGETHER_AI_80_B = 80
|
||||||
|
TOGETHER_AI_110_B = 110
|
||||||
|
TOGETHER_AI_EMBEDDING_150_M = 150
|
||||||
|
TOGETHER_AI_EMBEDDING_350_M = 350
|
||||||
|
QDRANT_SCALAR_QUANTILE = 0.99
|
||||||
|
QDRANT_VECTOR_SIZE = 1536
|
||||||
|
CACHED_STREAMING_CHUNK_DELAY = 0.02
|
||||||
|
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 512
|
||||||
|
DEFAULT_MAX_TOKENS_FOR_TRITON = 2000
|
||||||
#### Networking settings ####
|
#### Networking settings ####
|
||||||
request_timeout: float = 6000 # time in seconds
|
request_timeout: float = 6000 # time in seconds
|
||||||
STREAM_SSE_DONE_STRING: str = "[DONE]"
|
STREAM_SSE_DONE_STRING: str = "[DONE]"
|
||||||
|
### SPEND TRACKING ###
|
||||||
|
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND = 0.001400 # price per second for a100 80GB
|
||||||
|
FIREWORKS_AI_56_B_MOE = 56
|
||||||
|
FIREWORKS_AI_176_B_MOE = 176
|
||||||
|
FIREWORKS_AI_16_B = 16
|
||||||
|
FIREWORKS_AI_80_B = 80
|
||||||
|
|
||||||
LITELLM_CHAT_PROVIDERS = [
|
LITELLM_CHAT_PROVIDERS = [
|
||||||
"openai",
|
"openai",
|
||||||
|
@ -426,6 +487,9 @@ MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
||||||
MAX_SPENDLOG_ROWS_TO_QUERY = (
|
MAX_SPENDLOG_ROWS_TO_QUERY = (
|
||||||
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
|
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
|
||||||
)
|
)
|
||||||
|
DEFAULT_SOFT_BUDGET = (
|
||||||
|
50.0 # by default all litellm proxy keys have a soft budget of 50.0
|
||||||
|
)
|
||||||
# makes it clear this is a rate limit error for a litellm virtual key
|
# makes it clear this is a rate limit error for a litellm virtual key
|
||||||
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
|
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
|
||||||
|
|
||||||
|
@ -451,3 +515,14 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
|
||||||
########################### DB CRON JOB NAMES ###########################
|
########################### DB CRON JOB NAMES ###########################
|
||||||
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
|
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
|
||||||
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
|
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
|
||||||
|
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
|
||||||
|
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
|
||||||
|
PROXY_BATCH_WRITE_AT = 10 # in seconds
|
||||||
|
DEFAULT_HEALTH_CHECK_INTERVAL = 300 # 5 minutes
|
||||||
|
PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = 9
|
||||||
|
DEFAULT_MODEL_CREATED_AT_TIME = 1677610602 # returns on `/models` endpoint
|
||||||
|
DEFAULT_SLACK_ALERTING_THRESHOLD = 300
|
||||||
|
MAX_TEAM_LIST_LIMIT = 20
|
||||||
|
DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = 0.7
|
||||||
|
LENGTH_OF_LITELLM_GENERATED_KEY = 16
|
||||||
|
SECRET_MANAGER_REFRESH_INTERVAL = 86400
|
||||||
|
|
|
@ -9,6 +9,10 @@ from pydantic import BaseModel
|
||||||
import litellm
|
import litellm
|
||||||
import litellm._logging
|
import litellm._logging
|
||||||
from litellm import verbose_logger
|
from litellm import verbose_logger
|
||||||
|
from litellm.constants import (
|
||||||
|
DEFAULT_MAX_LRU_CACHE_SIZE,
|
||||||
|
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND,
|
||||||
|
)
|
||||||
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
||||||
StandardBuiltInToolCostTracking,
|
StandardBuiltInToolCostTracking,
|
||||||
)
|
)
|
||||||
|
@ -355,9 +359,7 @@ def cost_per_token( # noqa: PLR0915
|
||||||
def get_replicate_completion_pricing(completion_response: dict, total_time=0.0):
|
def get_replicate_completion_pricing(completion_response: dict, total_time=0.0):
|
||||||
# see https://replicate.com/pricing
|
# see https://replicate.com/pricing
|
||||||
# for all litellm currently supported LLMs, almost all requests go to a100_80gb
|
# for all litellm currently supported LLMs, almost all requests go to a100_80gb
|
||||||
a100_80gb_price_per_second_public = (
|
a100_80gb_price_per_second_public = DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND # assume all calls sent to A100 80GB for now
|
||||||
0.001400 # assume all calls sent to A100 80GB for now
|
|
||||||
)
|
|
||||||
if total_time == 0.0: # total time is in ms
|
if total_time == 0.0: # total time is in ms
|
||||||
start_time = completion_response.get("created", time.time())
|
start_time = completion_response.get("created", time.time())
|
||||||
end_time = getattr(completion_response, "ended", time.time())
|
end_time = getattr(completion_response, "ended", time.time())
|
||||||
|
@ -450,7 +452,7 @@ def _select_model_name_for_cost_calc(
|
||||||
return return_model
|
return return_model
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=16)
|
@lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE)
|
||||||
def _model_contains_known_llm_provider(model: str) -> bool:
|
def _model_contains_known_llm_provider(model: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the model contains a known llm provider
|
Check if the model contains a known llm provider
|
||||||
|
|
|
@ -63,16 +63,17 @@ async def acreate_file(
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
kwargs["acreate_file"] = True
|
kwargs["acreate_file"] = True
|
||||||
|
|
||||||
# Use a partial function to pass your keyword arguments
|
call_args = {
|
||||||
func = partial(
|
"file": file,
|
||||||
create_file,
|
"purpose": purpose,
|
||||||
file,
|
"custom_llm_provider": custom_llm_provider,
|
||||||
purpose,
|
"extra_headers": extra_headers,
|
||||||
custom_llm_provider,
|
"extra_body": extra_body,
|
||||||
extra_headers,
|
|
||||||
extra_body,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
}
|
||||||
|
|
||||||
|
# Use a partial function to pass your keyword arguments
|
||||||
|
func = partial(create_file, **call_args)
|
||||||
|
|
||||||
# Add the context to the function
|
# Add the context to the function
|
||||||
ctx = contextvars.copy_context()
|
ctx = contextvars.copy_context()
|
||||||
|
@ -92,7 +93,7 @@ async def acreate_file(
|
||||||
def create_file(
|
def create_file(
|
||||||
file: FileTypes,
|
file: FileTypes,
|
||||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
|
||||||
extra_headers: Optional[Dict[str, str]] = None,
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
extra_body: Optional[Dict[str, str]] = None,
|
extra_body: Optional[Dict[str, str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -101,6 +102,8 @@ def create_file(
|
||||||
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
||||||
|
|
||||||
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
||||||
|
|
||||||
|
Specify either provider_list or custom_llm_provider.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
_is_async = kwargs.pop("acreate_file", False) is True
|
_is_async = kwargs.pop("acreate_file", False) is True
|
||||||
|
@ -120,7 +123,7 @@ def create_file(
|
||||||
if (
|
if (
|
||||||
timeout is not None
|
timeout is not None
|
||||||
and isinstance(timeout, httpx.Timeout)
|
and isinstance(timeout, httpx.Timeout)
|
||||||
and supports_httpx_timeout(custom_llm_provider) is False
|
and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
|
||||||
):
|
):
|
||||||
read_timeout = timeout.read or 600
|
read_timeout = timeout.read or 600
|
||||||
timeout = read_timeout # default 10 min timeout
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
|
|
@ -16,6 +16,7 @@ import litellm.litellm_core_utils.litellm_logging
|
||||||
import litellm.types
|
import litellm.types
|
||||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
|
from litellm.constants import HOURS_IN_A_DAY
|
||||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||||
from litellm.litellm_core_utils.exception_mapping_utils import (
|
from litellm.litellm_core_utils.exception_mapping_utils import (
|
||||||
|
@ -649,10 +650,10 @@ class SlackAlerting(CustomBatchLogger):
|
||||||
event_message += (
|
event_message += (
|
||||||
f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
|
f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
|
||||||
)
|
)
|
||||||
elif percent_left <= 0.05:
|
elif percent_left <= SLACK_ALERTING_THRESHOLD_5_PERCENT:
|
||||||
event = "threshold_crossed"
|
event = "threshold_crossed"
|
||||||
event_message += "5% Threshold Crossed "
|
event_message += "5% Threshold Crossed "
|
||||||
elif percent_left <= 0.15:
|
elif percent_left <= SLACK_ALERTING_THRESHOLD_15_PERCENT:
|
||||||
event = "threshold_crossed"
|
event = "threshold_crossed"
|
||||||
event_message += "15% Threshold Crossed"
|
event_message += "15% Threshold Crossed"
|
||||||
elif user_info.soft_budget is not None:
|
elif user_info.soft_budget is not None:
|
||||||
|
@ -1718,7 +1719,7 @@ Model Info:
|
||||||
await self.internal_usage_cache.async_set_cache(
|
await self.internal_usage_cache.async_set_cache(
|
||||||
key=_event_cache_key,
|
key=_event_cache_key,
|
||||||
value="SENT",
|
value="SENT",
|
||||||
ttl=(30 * 24 * 60 * 60), # 1 month
|
ttl=(30 * HOURS_IN_A_DAY * 60 * 60), # 1 month
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -41,7 +41,7 @@ from litellm.types.utils import StandardLoggingPayload
|
||||||
from ..additional_logging_utils import AdditionalLoggingUtils
|
from ..additional_logging_utils import AdditionalLoggingUtils
|
||||||
|
|
||||||
# max number of logs DD API can accept
|
# max number of logs DD API can accept
|
||||||
DD_MAX_BATCH_SIZE = 1000
|
|
||||||
|
|
||||||
# specify what ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with large number of service types)
|
# specify what ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with large number of service types)
|
||||||
DD_LOGGED_SUCCESS_SERVICE_TYPES = [
|
DD_LOGGED_SUCCESS_SERVICE_TYPES = [
|
||||||
|
|
|
@ -20,10 +20,6 @@ else:
|
||||||
VertexBase = Any
|
VertexBase = Any
|
||||||
|
|
||||||
|
|
||||||
GCS_DEFAULT_BATCH_SIZE = 2048
|
|
||||||
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
|
|
||||||
|
|
||||||
|
|
||||||
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||||
def __init__(self, bucket_name: Optional[str] = None) -> None:
|
def __init__(self, bucket_name: Optional[str] = None) -> None:
|
||||||
from litellm.proxy.proxy_server import premium_user
|
from litellm.proxy.proxy_server import premium_user
|
||||||
|
@ -125,6 +121,7 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||||
kwargs
|
kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
headers = await self.construct_request_headers(
|
headers = await self.construct_request_headers(
|
||||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||||
service_account_json=gcs_logging_config["path_service_account"],
|
service_account_json=gcs_logging_config["path_service_account"],
|
||||||
|
|
|
@ -818,7 +818,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
requested_model=request_data.get("model", ""),
|
requested_model=request_data.get("model", ""),
|
||||||
status_code=str(getattr(original_exception, "status_code", None)),
|
status_code=str(getattr(original_exception, "status_code", None)),
|
||||||
exception_status=str(getattr(original_exception, "status_code", None)),
|
exception_status=str(getattr(original_exception, "status_code", None)),
|
||||||
exception_class=str(original_exception.__class__.__name__),
|
exception_class=self._get_exception_class_name(original_exception),
|
||||||
tags=_tags,
|
tags=_tags,
|
||||||
)
|
)
|
||||||
_labels = prometheus_label_factory(
|
_labels = prometheus_label_factory(
|
||||||
|
@ -917,7 +917,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
api_provider=llm_provider,
|
api_provider=llm_provider,
|
||||||
exception_status=str(getattr(exception, "status_code", None)),
|
exception_status=str(getattr(exception, "status_code", None)),
|
||||||
exception_class=exception.__class__.__name__,
|
exception_class=self._get_exception_class_name(exception),
|
||||||
requested_model=model_group,
|
requested_model=model_group,
|
||||||
hashed_api_key=standard_logging_payload["metadata"][
|
hashed_api_key=standard_logging_payload["metadata"][
|
||||||
"user_api_key_hash"
|
"user_api_key_hash"
|
||||||
|
@ -1146,6 +1146,22 @@ class PrometheusLogger(CustomLogger):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_exception_class_name(exception: Exception) -> str:
|
||||||
|
exception_class_name = ""
|
||||||
|
if hasattr(exception, "llm_provider"):
|
||||||
|
exception_class_name = getattr(exception, "llm_provider") or ""
|
||||||
|
|
||||||
|
# pretty print the provider name on prometheus
|
||||||
|
# eg. `openai` -> `Openai.`
|
||||||
|
if len(exception_class_name) >= 1:
|
||||||
|
exception_class_name = (
|
||||||
|
exception_class_name[0].upper() + exception_class_name[1:] + "."
|
||||||
|
)
|
||||||
|
|
||||||
|
exception_class_name += exception.__class__.__name__
|
||||||
|
return exception_class_name
|
||||||
|
|
||||||
async def log_success_fallback_event(
|
async def log_success_fallback_event(
|
||||||
self, original_model_group: str, kwargs: dict, original_exception: Exception
|
self, original_model_group: str, kwargs: dict, original_exception: Exception
|
||||||
):
|
):
|
||||||
|
@ -1181,7 +1197,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
team=standard_metadata["user_api_key_team_id"],
|
team=standard_metadata["user_api_key_team_id"],
|
||||||
team_alias=standard_metadata["user_api_key_team_alias"],
|
team_alias=standard_metadata["user_api_key_team_alias"],
|
||||||
exception_status=str(getattr(original_exception, "status_code", None)),
|
exception_status=str(getattr(original_exception, "status_code", None)),
|
||||||
exception_class=str(original_exception.__class__.__name__),
|
exception_class=self._get_exception_class_name(original_exception),
|
||||||
tags=_tags,
|
tags=_tags,
|
||||||
)
|
)
|
||||||
_labels = prometheus_label_factory(
|
_labels = prometheus_label_factory(
|
||||||
|
@ -1225,7 +1241,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
team=standard_metadata["user_api_key_team_id"],
|
team=standard_metadata["user_api_key_team_id"],
|
||||||
team_alias=standard_metadata["user_api_key_team_alias"],
|
team_alias=standard_metadata["user_api_key_team_alias"],
|
||||||
exception_status=str(getattr(original_exception, "status_code", None)),
|
exception_status=str(getattr(original_exception, "status_code", None)),
|
||||||
exception_class=str(original_exception.__class__.__name__),
|
exception_class=self._get_exception_class_name(original_exception),
|
||||||
tags=_tags,
|
tags=_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1721,6 +1737,36 @@ class PrometheusLogger(CustomLogger):
|
||||||
return (end_time - start_time).total_seconds()
|
return (end_time - start_time).total_seconds()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _mount_metrics_endpoint(premium_user: bool):
|
||||||
|
"""
|
||||||
|
Mount the Prometheus metrics endpoint with optional authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
premium_user (bool): Whether the user is a premium user
|
||||||
|
require_auth (bool, optional): Whether to require authentication for the metrics endpoint.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
from prometheus_client import make_asgi_app
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import CommonProxyErrors
|
||||||
|
from litellm.proxy.proxy_server import app
|
||||||
|
|
||||||
|
if premium_user is not True:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
f"Prometheus metrics are only available for premium users. {CommonProxyErrors.not_premium_user.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create metrics ASGI app
|
||||||
|
metrics_app = make_asgi_app()
|
||||||
|
|
||||||
|
# Mount the metrics app to the app
|
||||||
|
app.mount("/metrics", metrics_app)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Starting Prometheus Metrics on /metrics (no authentication)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def prometheus_label_factory(
|
def prometheus_label_factory(
|
||||||
supported_enum_labels: List[str],
|
supported_enum_labels: List[str],
|
||||||
|
|
|
@ -3,11 +3,16 @@
|
||||||
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
|
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
|
||||||
|
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from litellm._logging import print_verbose, verbose_logger
|
from litellm._logging import print_verbose, verbose_logger
|
||||||
from litellm.types.integrations.prometheus import LATENCY_BUCKETS
|
from litellm.types.integrations.prometheus import LATENCY_BUCKETS
|
||||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
from litellm.types.services import (
|
||||||
|
DEFAULT_SERVICE_CONFIGS,
|
||||||
|
ServiceLoggerPayload,
|
||||||
|
ServiceMetrics,
|
||||||
|
ServiceTypes,
|
||||||
|
)
|
||||||
|
|
||||||
FAILED_REQUESTS_LABELS = ["error_class", "function_name"]
|
FAILED_REQUESTS_LABELS = ["error_class", "function_name"]
|
||||||
|
|
||||||
|
@ -23,7 +28,8 @@ class PrometheusServicesLogger:
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
from prometheus_client import REGISTRY, Counter, Histogram
|
from prometheus_client import REGISTRY, Counter, Gauge, Histogram
|
||||||
|
from prometheus_client.gc_collector import Collector
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Missing prometheus_client. Run `pip install prometheus-client`"
|
"Missing prometheus_client. Run `pip install prometheus-client`"
|
||||||
|
@ -31,36 +37,51 @@ class PrometheusServicesLogger:
|
||||||
|
|
||||||
self.Histogram = Histogram
|
self.Histogram = Histogram
|
||||||
self.Counter = Counter
|
self.Counter = Counter
|
||||||
|
self.Gauge = Gauge
|
||||||
self.REGISTRY = REGISTRY
|
self.REGISTRY = REGISTRY
|
||||||
|
|
||||||
verbose_logger.debug("in init prometheus services metrics")
|
verbose_logger.debug("in init prometheus services metrics")
|
||||||
|
|
||||||
self.services = [item.value for item in ServiceTypes]
|
self.payload_to_prometheus_map: Dict[
|
||||||
|
str, List[Union[Histogram, Counter, Gauge, Collector]]
|
||||||
|
] = {}
|
||||||
|
|
||||||
self.payload_to_prometheus_map = (
|
for service in ServiceTypes:
|
||||||
{}
|
service_metrics: List[Union[Histogram, Counter, Gauge, Collector]] = []
|
||||||
) # store the prometheus histogram/counter we need to call for each field in payload
|
|
||||||
|
|
||||||
for service in self.services:
|
metrics_to_initialize = self._get_service_metrics_initialize(service)
|
||||||
histogram = self.create_histogram(service, type_of_request="latency")
|
|
||||||
counter_failed_request = self.create_counter(
|
|
||||||
service,
|
|
||||||
type_of_request="failed_requests",
|
|
||||||
additional_labels=FAILED_REQUESTS_LABELS,
|
|
||||||
)
|
|
||||||
counter_total_requests = self.create_counter(
|
|
||||||
service, type_of_request="total_requests"
|
|
||||||
)
|
|
||||||
self.payload_to_prometheus_map[service] = [
|
|
||||||
histogram,
|
|
||||||
counter_failed_request,
|
|
||||||
counter_total_requests,
|
|
||||||
]
|
|
||||||
|
|
||||||
self.prometheus_to_amount_map: dict = (
|
# Initialize only the configured metrics for each service
|
||||||
{}
|
if ServiceMetrics.HISTOGRAM in metrics_to_initialize:
|
||||||
) # the field / value in ServiceLoggerPayload the object needs to be incremented by
|
histogram = self.create_histogram(
|
||||||
|
service.value, type_of_request="latency"
|
||||||
|
)
|
||||||
|
if histogram:
|
||||||
|
service_metrics.append(histogram)
|
||||||
|
|
||||||
|
if ServiceMetrics.COUNTER in metrics_to_initialize:
|
||||||
|
counter_failed_request = self.create_counter(
|
||||||
|
service.value,
|
||||||
|
type_of_request="failed_requests",
|
||||||
|
additional_labels=FAILED_REQUESTS_LABELS,
|
||||||
|
)
|
||||||
|
if counter_failed_request:
|
||||||
|
service_metrics.append(counter_failed_request)
|
||||||
|
counter_total_requests = self.create_counter(
|
||||||
|
service.value, type_of_request="total_requests"
|
||||||
|
)
|
||||||
|
if counter_total_requests:
|
||||||
|
service_metrics.append(counter_total_requests)
|
||||||
|
|
||||||
|
if ServiceMetrics.GAUGE in metrics_to_initialize:
|
||||||
|
gauge = self.create_gauge(service.value, type_of_request="size")
|
||||||
|
if gauge:
|
||||||
|
service_metrics.append(gauge)
|
||||||
|
|
||||||
|
if service_metrics:
|
||||||
|
self.payload_to_prometheus_map[service.value] = service_metrics
|
||||||
|
|
||||||
|
self.prometheus_to_amount_map: dict = {}
|
||||||
### MOCK TESTING ###
|
### MOCK TESTING ###
|
||||||
self.mock_testing = mock_testing
|
self.mock_testing = mock_testing
|
||||||
self.mock_testing_success_calls = 0
|
self.mock_testing_success_calls = 0
|
||||||
|
@ -70,6 +91,19 @@ class PrometheusServicesLogger:
|
||||||
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _get_service_metrics_initialize(
|
||||||
|
self, service: ServiceTypes
|
||||||
|
) -> List[ServiceMetrics]:
|
||||||
|
DEFAULT_METRICS = [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
|
||||||
|
if service not in DEFAULT_SERVICE_CONFIGS:
|
||||||
|
return DEFAULT_METRICS
|
||||||
|
|
||||||
|
metrics = DEFAULT_SERVICE_CONFIGS.get(service, {}).get("metrics", [])
|
||||||
|
if not metrics:
|
||||||
|
verbose_logger.debug(f"No metrics found for service {service}")
|
||||||
|
return DEFAULT_METRICS
|
||||||
|
return metrics
|
||||||
|
|
||||||
def is_metric_registered(self, metric_name) -> bool:
|
def is_metric_registered(self, metric_name) -> bool:
|
||||||
for metric in self.REGISTRY.collect():
|
for metric in self.REGISTRY.collect():
|
||||||
if metric_name == metric.name:
|
if metric_name == metric.name:
|
||||||
|
@ -94,6 +128,15 @@ class PrometheusServicesLogger:
|
||||||
buckets=LATENCY_BUCKETS,
|
buckets=LATENCY_BUCKETS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_gauge(self, service: str, type_of_request: str):
|
||||||
|
metric_name = "litellm_{}_{}".format(service, type_of_request)
|
||||||
|
is_registered = self.is_metric_registered(metric_name)
|
||||||
|
if is_registered:
|
||||||
|
return self._get_metric(metric_name)
|
||||||
|
return self.Gauge(
|
||||||
|
metric_name, "Gauge for {} service".format(service), labelnames=[service]
|
||||||
|
)
|
||||||
|
|
||||||
def create_counter(
|
def create_counter(
|
||||||
self,
|
self,
|
||||||
service: str,
|
service: str,
|
||||||
|
@ -120,6 +163,15 @@ class PrometheusServicesLogger:
|
||||||
|
|
||||||
histogram.labels(labels).observe(amount)
|
histogram.labels(labels).observe(amount)
|
||||||
|
|
||||||
|
def update_gauge(
|
||||||
|
self,
|
||||||
|
gauge,
|
||||||
|
labels: str,
|
||||||
|
amount: float,
|
||||||
|
):
|
||||||
|
assert isinstance(gauge, self.Gauge)
|
||||||
|
gauge.labels(labels).set(amount)
|
||||||
|
|
||||||
def increment_counter(
|
def increment_counter(
|
||||||
self,
|
self,
|
||||||
counter,
|
counter,
|
||||||
|
@ -190,6 +242,13 @@ class PrometheusServicesLogger:
|
||||||
labels=payload.service.value,
|
labels=payload.service.value,
|
||||||
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
||||||
)
|
)
|
||||||
|
elif isinstance(obj, self.Gauge):
|
||||||
|
if payload.event_metadata:
|
||||||
|
self.update_gauge(
|
||||||
|
gauge=obj,
|
||||||
|
labels=payload.event_metadata.get("gauge_labels") or "",
|
||||||
|
amount=payload.event_metadata.get("gauge_value") or 0,
|
||||||
|
)
|
||||||
|
|
||||||
async def async_service_failure_hook(
|
async def async_service_failure_hook(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -10,6 +10,7 @@ class CredentialAccessor:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_credential_values(credential_name: str) -> dict:
|
def get_credential_values(credential_name: str) -> dict:
|
||||||
"""Safe accessor for credentials."""
|
"""Safe accessor for credentials."""
|
||||||
|
|
||||||
if not litellm.credential_list:
|
if not litellm.credential_list:
|
||||||
return {}
|
return {}
|
||||||
for credential in litellm.credential_list:
|
for credential in litellm.credential_list:
|
||||||
|
|
|
@ -3,6 +3,7 @@ from typing import Optional, Tuple
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.constants import REPLICATE_MODEL_NAME_WITH_ID_LENGTH
|
||||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||||
|
|
||||||
from ..types.router import LiteLLM_Params
|
from ..types.router import LiteLLM_Params
|
||||||
|
@ -256,10 +257,13 @@ def get_llm_provider( # noqa: PLR0915
|
||||||
elif model in litellm.cohere_chat_models:
|
elif model in litellm.cohere_chat_models:
|
||||||
custom_llm_provider = "cohere_chat"
|
custom_llm_provider = "cohere_chat"
|
||||||
## replicate
|
## replicate
|
||||||
elif model in litellm.replicate_models or (":" in model and len(model) > 64):
|
elif model in litellm.replicate_models or (
|
||||||
|
":" in model and len(model) > REPLICATE_MODEL_NAME_WITH_ID_LENGTH
|
||||||
|
):
|
||||||
model_parts = model.split(":")
|
model_parts = model.split(":")
|
||||||
if (
|
if (
|
||||||
len(model_parts) > 1 and len(model_parts[1]) == 64
|
len(model_parts) > 1
|
||||||
|
and len(model_parts[1]) == REPLICATE_MODEL_NAME_WITH_ID_LENGTH
|
||||||
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
||||||
custom_llm_provider = "replicate"
|
custom_llm_provider = "replicate"
|
||||||
elif model in litellm.replicate_models:
|
elif model in litellm.replicate_models:
|
||||||
|
|
|
@ -120,7 +120,7 @@ def get_supported_openai_params( # noqa: PLR0915
|
||||||
elif custom_llm_provider == "replicate":
|
elif custom_llm_provider == "replicate":
|
||||||
return litellm.ReplicateConfig().get_supported_openai_params(model=model)
|
return litellm.ReplicateConfig().get_supported_openai_params(model=model)
|
||||||
elif custom_llm_provider == "huggingface":
|
elif custom_llm_provider == "huggingface":
|
||||||
return litellm.HuggingfaceConfig().get_supported_openai_params(model=model)
|
return litellm.HuggingFaceChatConfig().get_supported_openai_params(model=model)
|
||||||
elif custom_llm_provider == "jina_ai":
|
elif custom_llm_provider == "jina_ai":
|
||||||
if request_type == "embeddings":
|
if request_type == "embeddings":
|
||||||
return litellm.JinaAIEmbeddingConfig().get_supported_openai_params()
|
return litellm.JinaAIEmbeddingConfig().get_supported_openai_params()
|
||||||
|
|
|
@ -28,6 +28,10 @@ from litellm._logging import _is_debugging_on, verbose_logger
|
||||||
from litellm.batches.batch_utils import _handle_completed_batch
|
from litellm.batches.batch_utils import _handle_completed_batch
|
||||||
from litellm.caching.caching import DualCache, InMemoryCache
|
from litellm.caching.caching import DualCache, InMemoryCache
|
||||||
from litellm.caching.caching_handler import LLMCachingHandler
|
from litellm.caching.caching_handler import LLMCachingHandler
|
||||||
|
from litellm.constants import (
|
||||||
|
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
|
||||||
|
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
|
||||||
|
)
|
||||||
from litellm.cost_calculator import _select_model_name_for_cost_calc
|
from litellm.cost_calculator import _select_model_name_for_cost_calc
|
||||||
from litellm.integrations.arize.arize import ArizeLogger
|
from litellm.integrations.arize.arize import ArizeLogger
|
||||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
|
@ -453,8 +457,12 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
non_default_params: dict,
|
non_default_params: dict,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
prompt_variables: Optional[dict],
|
prompt_variables: Optional[dict],
|
||||||
|
prompt_management_logger: Optional[CustomLogger] = None,
|
||||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||||
custom_logger = self.get_custom_logger_for_prompt_management(model)
|
custom_logger = (
|
||||||
|
prompt_management_logger
|
||||||
|
or self.get_custom_logger_for_prompt_management(model)
|
||||||
|
)
|
||||||
if custom_logger:
|
if custom_logger:
|
||||||
(
|
(
|
||||||
model,
|
model,
|
||||||
|
@ -3745,9 +3753,12 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
response_cost_failure_debug_info=None,
|
response_cost_failure_debug_info=None,
|
||||||
status=str("success"),
|
status=str("success"),
|
||||||
total_tokens=int(30),
|
total_tokens=int(
|
||||||
prompt_tokens=int(20),
|
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT
|
||||||
completion_tokens=int(10),
|
+ DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT
|
||||||
|
),
|
||||||
|
prompt_tokens=int(DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT),
|
||||||
|
completion_tokens=int(DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT),
|
||||||
startTime=start_time,
|
startTime=start_time,
|
||||||
endTime=end_time,
|
endTime=end_time,
|
||||||
completionStartTime=completion_start_time,
|
completionStartTime=completion_start_time,
|
||||||
|
|
|
@ -5,6 +5,7 @@ Helper utilities for tracking the cost of built-in tools.
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.constants import OPENAI_FILE_SEARCH_COST_PER_1K_CALLS
|
||||||
from litellm.types.llms.openai import FileSearchTool, WebSearchOptions
|
from litellm.types.llms.openai import FileSearchTool, WebSearchOptions
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
|
@ -132,7 +133,7 @@ class StandardBuiltInToolCostTracking:
|
||||||
"""
|
"""
|
||||||
if file_search is None:
|
if file_search is None:
|
||||||
return 0.0
|
return 0.0
|
||||||
return 2.5 / 1000
|
return OPENAI_FILE_SEARCH_COST_PER_1K_CALLS
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def chat_completion_response_includes_annotations(
|
def chat_completion_response_includes_annotations(
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
|
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
from litellm.types.llms.databricks import DatabricksTool
|
||||||
from litellm.types.llms.openai import ChatCompletionThinkingBlock
|
from litellm.types.llms.openai import ChatCompletionThinkingBlock
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
ChatCompletionDeltaToolCall,
|
ChatCompletionDeltaToolCall,
|
||||||
|
@ -35,6 +36,25 @@ from litellm.types.utils import (
|
||||||
from .get_headers import get_response_headers
|
from .get_headers import get_response_headers
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tool_call_to_json_mode(
|
||||||
|
tool_calls: List[ChatCompletionMessageToolCall],
|
||||||
|
convert_tool_call_to_json_mode: bool,
|
||||||
|
) -> Tuple[Optional[Message], Optional[str]]:
|
||||||
|
if _should_convert_tool_call_to_json_mode(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
|
||||||
|
):
|
||||||
|
# to support 'json_schema' logic on older models
|
||||||
|
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
|
||||||
|
"arguments"
|
||||||
|
)
|
||||||
|
if json_mode_content_str is not None:
|
||||||
|
message = litellm.Message(content=json_mode_content_str)
|
||||||
|
finish_reason = "stop"
|
||||||
|
return message, finish_reason
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
async def convert_to_streaming_response_async(response_object: Optional[dict] = None):
|
async def convert_to_streaming_response_async(response_object: Optional[dict] = None):
|
||||||
"""
|
"""
|
||||||
Asynchronously converts a response object to a streaming response.
|
Asynchronously converts a response object to a streaming response.
|
||||||
|
@ -335,21 +355,14 @@ class LiteLLMResponseObjectHandler:
|
||||||
Only supported for HF TGI models
|
Only supported for HF TGI models
|
||||||
"""
|
"""
|
||||||
transformed_logprobs: Optional[TextCompletionLogprobs] = None
|
transformed_logprobs: Optional[TextCompletionLogprobs] = None
|
||||||
if custom_llm_provider == "huggingface":
|
|
||||||
# only supported for TGI models
|
|
||||||
try:
|
|
||||||
raw_response = response._hidden_params.get("original_response", None)
|
|
||||||
transformed_logprobs = litellm.huggingface._transform_logprobs(
|
|
||||||
hf_response=raw_response
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
|
|
||||||
|
|
||||||
return transformed_logprobs
|
return transformed_logprobs
|
||||||
|
|
||||||
|
|
||||||
def _should_convert_tool_call_to_json_mode(
|
def _should_convert_tool_call_to_json_mode(
|
||||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None,
|
tool_calls: Optional[
|
||||||
|
Union[List[ChatCompletionMessageToolCall], List[DatabricksTool]]
|
||||||
|
] = None,
|
||||||
convert_tool_call_to_json_mode: Optional[bool] = None,
|
convert_tool_call_to_json_mode: Optional[bool] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -7,6 +7,7 @@ from typing import Dict, List, Literal, Optional, Union, cast
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
AllMessageValues,
|
AllMessageValues,
|
||||||
ChatCompletionAssistantMessage,
|
ChatCompletionAssistantMessage,
|
||||||
|
ChatCompletionFileObject,
|
||||||
ChatCompletionUserMessage,
|
ChatCompletionUserMessage,
|
||||||
)
|
)
|
||||||
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
|
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
|
||||||
|
@ -34,7 +35,7 @@ def handle_messages_with_content_list_to_str_conversion(
|
||||||
|
|
||||||
|
|
||||||
def strip_name_from_messages(
|
def strip_name_from_messages(
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues], allowed_name_roles: List[str] = ["user"]
|
||||||
) -> List[AllMessageValues]:
|
) -> List[AllMessageValues]:
|
||||||
"""
|
"""
|
||||||
Removes 'name' from messages
|
Removes 'name' from messages
|
||||||
|
@ -43,7 +44,7 @@ def strip_name_from_messages(
|
||||||
for message in messages:
|
for message in messages:
|
||||||
msg_role = message.get("role")
|
msg_role = message.get("role")
|
||||||
msg_copy = message.copy()
|
msg_copy = message.copy()
|
||||||
if msg_role == "user":
|
if msg_role not in allowed_name_roles:
|
||||||
msg_copy.pop("name", None) # type: ignore
|
msg_copy.pop("name", None) # type: ignore
|
||||||
new_messages.append(msg_copy)
|
new_messages.append(msg_copy)
|
||||||
return new_messages
|
return new_messages
|
||||||
|
@ -292,3 +293,58 @@ def get_completion_messages(
|
||||||
messages, assistant_continue_message, ensure_alternating_roles
|
messages, assistant_continue_message, ensure_alternating_roles
|
||||||
)
|
)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Gets file ids from messages
|
||||||
|
"""
|
||||||
|
file_ids = []
|
||||||
|
for message in messages:
|
||||||
|
if message.get("role") == "user":
|
||||||
|
content = message.get("content")
|
||||||
|
if content:
|
||||||
|
if isinstance(content, str):
|
||||||
|
continue
|
||||||
|
for c in content:
|
||||||
|
if c["type"] == "file":
|
||||||
|
file_object = cast(ChatCompletionFileObject, c)
|
||||||
|
file_object_file_field = file_object["file"]
|
||||||
|
file_id = file_object_file_field.get("file_id")
|
||||||
|
if file_id:
|
||||||
|
file_ids.append(file_id)
|
||||||
|
return file_ids
|
||||||
|
|
||||||
|
|
||||||
|
def update_messages_with_model_file_ids(
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
model_id: str,
|
||||||
|
model_file_id_mapping: Dict[str, Dict[str, str]],
|
||||||
|
) -> List[AllMessageValues]:
|
||||||
|
"""
|
||||||
|
Updates messages with model file ids.
|
||||||
|
|
||||||
|
model_file_id_mapping: Dict[str, Dict[str, str]] = {
|
||||||
|
"litellm_proxy/file_id": {
|
||||||
|
"model_id": "provider_file_id"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
for message in messages:
|
||||||
|
if message.get("role") == "user":
|
||||||
|
content = message.get("content")
|
||||||
|
if content:
|
||||||
|
if isinstance(content, str):
|
||||||
|
continue
|
||||||
|
for c in content:
|
||||||
|
if c["type"] == "file":
|
||||||
|
file_object = cast(ChatCompletionFileObject, c)
|
||||||
|
file_object_file_field = file_object["file"]
|
||||||
|
file_id = file_object_file_field.get("file_id")
|
||||||
|
if file_id:
|
||||||
|
provider_file_id = (
|
||||||
|
model_file_id_mapping.get(file_id, {}).get(model_id)
|
||||||
|
or file_id
|
||||||
|
)
|
||||||
|
file_object_file_field["file_id"] = provider_file_id
|
||||||
|
return messages
|
||||||
|
|
|
@ -1300,20 +1300,37 @@ def convert_to_anthropic_tool_invoke(
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
anthropic_tool_invoke = [
|
anthropic_tool_invoke = []
|
||||||
AnthropicMessagesToolUseParam(
|
|
||||||
|
for tool in tool_calls:
|
||||||
|
if not get_attribute_or_key(tool, "type") == "function":
|
||||||
|
continue
|
||||||
|
|
||||||
|
_anthropic_tool_use_param = AnthropicMessagesToolUseParam(
|
||||||
type="tool_use",
|
type="tool_use",
|
||||||
id=get_attribute_or_key(tool, "id"),
|
id=cast(str, get_attribute_or_key(tool, "id")),
|
||||||
name=get_attribute_or_key(get_attribute_or_key(tool, "function"), "name"),
|
name=cast(
|
||||||
|
str,
|
||||||
|
get_attribute_or_key(get_attribute_or_key(tool, "function"), "name"),
|
||||||
|
),
|
||||||
input=json.loads(
|
input=json.loads(
|
||||||
get_attribute_or_key(
|
get_attribute_or_key(
|
||||||
get_attribute_or_key(tool, "function"), "arguments"
|
get_attribute_or_key(tool, "function"), "arguments"
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for tool in tool_calls
|
|
||||||
if get_attribute_or_key(tool, "type") == "function"
|
_content_element = add_cache_control_to_content(
|
||||||
]
|
anthropic_content_element=_anthropic_tool_use_param,
|
||||||
|
orignal_content_element=dict(tool),
|
||||||
|
)
|
||||||
|
|
||||||
|
if "cache_control" in _content_element:
|
||||||
|
_anthropic_tool_use_param["cache_control"] = _content_element[
|
||||||
|
"cache_control"
|
||||||
|
]
|
||||||
|
|
||||||
|
anthropic_tool_invoke.append(_anthropic_tool_use_param)
|
||||||
|
|
||||||
return anthropic_tool_invoke
|
return anthropic_tool_invoke
|
||||||
|
|
||||||
|
@ -1324,6 +1341,7 @@ def add_cache_control_to_content(
|
||||||
AnthropicMessagesImageParam,
|
AnthropicMessagesImageParam,
|
||||||
AnthropicMessagesTextParam,
|
AnthropicMessagesTextParam,
|
||||||
AnthropicMessagesDocumentParam,
|
AnthropicMessagesDocumentParam,
|
||||||
|
AnthropicMessagesToolUseParam,
|
||||||
ChatCompletionThinkingBlock,
|
ChatCompletionThinkingBlock,
|
||||||
],
|
],
|
||||||
orignal_content_element: Union[dict, AllMessageValues],
|
orignal_content_element: Union[dict, AllMessageValues],
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
ChatCompletionAssistantContentValue,
|
ChatCompletionAssistantContentValue,
|
||||||
|
@ -9,7 +9,9 @@ from litellm.types.llms.openai import (
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
ChatCompletionAudioResponse,
|
ChatCompletionAudioResponse,
|
||||||
ChatCompletionMessageToolCall,
|
ChatCompletionMessageToolCall,
|
||||||
|
Choices,
|
||||||
CompletionTokensDetails,
|
CompletionTokensDetails,
|
||||||
|
CompletionTokensDetailsWrapper,
|
||||||
Function,
|
Function,
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
|
@ -203,14 +205,14 @@ class ChunkProcessor:
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_combined_content(
|
def get_combined_content(
|
||||||
self, chunks: List[Dict[str, Any]]
|
self, chunks: List[Dict[str, Any]], delta_key: str = "content"
|
||||||
) -> ChatCompletionAssistantContentValue:
|
) -> ChatCompletionAssistantContentValue:
|
||||||
content_list: List[str] = []
|
content_list: List[str] = []
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
choices = chunk["choices"]
|
choices = chunk["choices"]
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
delta = choice.get("delta", {})
|
delta = choice.get("delta", {})
|
||||||
content = delta.get("content", "")
|
content = delta.get(delta_key, "")
|
||||||
if content is None:
|
if content is None:
|
||||||
continue # openai v1.0.0 sets content = None for chunks
|
continue # openai v1.0.0 sets content = None for chunks
|
||||||
content_list.append(content)
|
content_list.append(content)
|
||||||
|
@ -221,6 +223,11 @@ class ChunkProcessor:
|
||||||
# Update the "content" field within the response dictionary
|
# Update the "content" field within the response dictionary
|
||||||
return combined_content
|
return combined_content
|
||||||
|
|
||||||
|
def get_combined_reasoning_content(
|
||||||
|
self, chunks: List[Dict[str, Any]]
|
||||||
|
) -> ChatCompletionAssistantContentValue:
|
||||||
|
return self.get_combined_content(chunks, delta_key="reasoning_content")
|
||||||
|
|
||||||
def get_combined_audio_content(
|
def get_combined_audio_content(
|
||||||
self, chunks: List[Dict[str, Any]]
|
self, chunks: List[Dict[str, Any]]
|
||||||
) -> ChatCompletionAudioResponse:
|
) -> ChatCompletionAudioResponse:
|
||||||
|
@ -296,12 +303,27 @@ class ChunkProcessor:
|
||||||
"prompt_tokens_details": prompt_tokens_details,
|
"prompt_tokens_details": prompt_tokens_details,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def count_reasoning_tokens(self, response: ModelResponse) -> int:
|
||||||
|
reasoning_tokens = 0
|
||||||
|
for choice in response.choices:
|
||||||
|
if (
|
||||||
|
hasattr(cast(Choices, choice).message, "reasoning_content")
|
||||||
|
and cast(Choices, choice).message.reasoning_content is not None
|
||||||
|
):
|
||||||
|
reasoning_tokens += token_counter(
|
||||||
|
text=cast(Choices, choice).message.reasoning_content,
|
||||||
|
count_response_tokens=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return reasoning_tokens
|
||||||
|
|
||||||
def calculate_usage(
|
def calculate_usage(
|
||||||
self,
|
self,
|
||||||
chunks: List[Union[Dict[str, Any], ModelResponse]],
|
chunks: List[Union[Dict[str, Any], ModelResponse]],
|
||||||
model: str,
|
model: str,
|
||||||
completion_output: str,
|
completion_output: str,
|
||||||
messages: Optional[List] = None,
|
messages: Optional[List] = None,
|
||||||
|
reasoning_tokens: Optional[int] = None,
|
||||||
) -> Usage:
|
) -> Usage:
|
||||||
"""
|
"""
|
||||||
Calculate usage for the given chunks.
|
Calculate usage for the given chunks.
|
||||||
|
@ -382,6 +404,19 @@ class ChunkProcessor:
|
||||||
) # for anthropic
|
) # for anthropic
|
||||||
if completion_tokens_details is not None:
|
if completion_tokens_details is not None:
|
||||||
returned_usage.completion_tokens_details = completion_tokens_details
|
returned_usage.completion_tokens_details = completion_tokens_details
|
||||||
|
|
||||||
|
if reasoning_tokens is not None:
|
||||||
|
if returned_usage.completion_tokens_details is None:
|
||||||
|
returned_usage.completion_tokens_details = (
|
||||||
|
CompletionTokensDetailsWrapper(reasoning_tokens=reasoning_tokens)
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
returned_usage.completion_tokens_details is not None
|
||||||
|
and returned_usage.completion_tokens_details.reasoning_tokens is None
|
||||||
|
):
|
||||||
|
returned_usage.completion_tokens_details.reasoning_tokens = (
|
||||||
|
reasoning_tokens
|
||||||
|
)
|
||||||
if prompt_tokens_details is not None:
|
if prompt_tokens_details is not None:
|
||||||
returned_usage.prompt_tokens_details = prompt_tokens_details
|
returned_usage.prompt_tokens_details = prompt_tokens_details
|
||||||
|
|
||||||
|
|
|
@ -214,10 +214,7 @@ class CustomStreamWrapper:
|
||||||
Output parse <s> / </s> special tokens for sagemaker + hf streaming.
|
Output parse <s> / </s> special tokens for sagemaker + hf streaming.
|
||||||
"""
|
"""
|
||||||
hold = False
|
hold = False
|
||||||
if (
|
if self.custom_llm_provider != "sagemaker":
|
||||||
self.custom_llm_provider != "huggingface"
|
|
||||||
and self.custom_llm_provider != "sagemaker"
|
|
||||||
):
|
|
||||||
return hold, chunk
|
return hold, chunk
|
||||||
|
|
||||||
if finish_reason:
|
if finish_reason:
|
||||||
|
@ -290,49 +287,6 @@ class CustomStreamWrapper:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def handle_huggingface_chunk(self, chunk):
|
|
||||||
try:
|
|
||||||
if not isinstance(chunk, str):
|
|
||||||
chunk = chunk.decode(
|
|
||||||
"utf-8"
|
|
||||||
) # DO NOT REMOVE this: This is required for HF inference API + Streaming
|
|
||||||
text = ""
|
|
||||||
is_finished = False
|
|
||||||
finish_reason = ""
|
|
||||||
print_verbose(f"chunk: {chunk}")
|
|
||||||
if chunk.startswith("data:"):
|
|
||||||
data_json = json.loads(chunk[5:])
|
|
||||||
print_verbose(f"data json: {data_json}")
|
|
||||||
if "token" in data_json and "text" in data_json["token"]:
|
|
||||||
text = data_json["token"]["text"]
|
|
||||||
if data_json.get("details", False) and data_json["details"].get(
|
|
||||||
"finish_reason", False
|
|
||||||
):
|
|
||||||
is_finished = True
|
|
||||||
finish_reason = data_json["details"]["finish_reason"]
|
|
||||||
elif data_json.get(
|
|
||||||
"generated_text", False
|
|
||||||
): # if full generated text exists, then stream is complete
|
|
||||||
text = "" # don't return the final bos token
|
|
||||||
is_finished = True
|
|
||||||
finish_reason = "stop"
|
|
||||||
elif data_json.get("error", False):
|
|
||||||
raise Exception(data_json.get("error"))
|
|
||||||
return {
|
|
||||||
"text": text,
|
|
||||||
"is_finished": is_finished,
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
}
|
|
||||||
elif "error" in chunk:
|
|
||||||
raise ValueError(chunk)
|
|
||||||
return {
|
|
||||||
"text": text,
|
|
||||||
"is_finished": is_finished,
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def handle_ai21_chunk(self, chunk): # fake streaming
|
def handle_ai21_chunk(self, chunk): # fake streaming
|
||||||
chunk = chunk.decode("utf-8")
|
chunk = chunk.decode("utf-8")
|
||||||
data_json = json.loads(chunk)
|
data_json = json.loads(chunk)
|
||||||
|
@ -1049,11 +1003,6 @@ class CustomStreamWrapper:
|
||||||
completion_obj["content"] = response_obj["text"]
|
completion_obj["content"] = response_obj["text"]
|
||||||
if response_obj["is_finished"]:
|
if response_obj["is_finished"]:
|
||||||
self.received_finish_reason = response_obj["finish_reason"]
|
self.received_finish_reason = response_obj["finish_reason"]
|
||||||
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
|
|
||||||
response_obj = self.handle_huggingface_chunk(chunk)
|
|
||||||
completion_obj["content"] = response_obj["text"]
|
|
||||||
if response_obj["is_finished"]:
|
|
||||||
self.received_finish_reason = response_obj["finish_reason"]
|
|
||||||
elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
|
elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
|
||||||
response_obj = self.handle_predibase_chunk(chunk)
|
response_obj = self.handle_predibase_chunk(chunk)
|
||||||
completion_obj["content"] = response_obj["text"]
|
completion_obj["content"] = response_obj["text"]
|
||||||
|
|
|
@ -11,6 +11,10 @@ from litellm.constants import (
|
||||||
DEFAULT_IMAGE_HEIGHT,
|
DEFAULT_IMAGE_HEIGHT,
|
||||||
DEFAULT_IMAGE_TOKEN_COUNT,
|
DEFAULT_IMAGE_TOKEN_COUNT,
|
||||||
DEFAULT_IMAGE_WIDTH,
|
DEFAULT_IMAGE_WIDTH,
|
||||||
|
MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES,
|
||||||
|
MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES,
|
||||||
|
MAX_TILE_HEIGHT,
|
||||||
|
MAX_TILE_WIDTH,
|
||||||
)
|
)
|
||||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||||
|
|
||||||
|
@ -97,11 +101,14 @@ def resize_image_high_res(
|
||||||
height: int,
|
height: int,
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
# Maximum dimensions for high res mode
|
# Maximum dimensions for high res mode
|
||||||
max_short_side = 768
|
max_short_side = MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES
|
||||||
max_long_side = 2000
|
max_long_side = MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES
|
||||||
|
|
||||||
# Return early if no resizing is needed
|
# Return early if no resizing is needed
|
||||||
if width <= 768 and height <= 768:
|
if (
|
||||||
|
width <= MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES
|
||||||
|
and height <= MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES
|
||||||
|
):
|
||||||
return width, height
|
return width, height
|
||||||
|
|
||||||
# Determine the longer and shorter sides
|
# Determine the longer and shorter sides
|
||||||
|
@ -132,7 +139,10 @@ def resize_image_high_res(
|
||||||
|
|
||||||
# Test the function with the given example
|
# Test the function with the given example
|
||||||
def calculate_tiles_needed(
|
def calculate_tiles_needed(
|
||||||
resized_width, resized_height, tile_width=512, tile_height=512
|
resized_width,
|
||||||
|
resized_height,
|
||||||
|
tile_width=MAX_TILE_WIDTH,
|
||||||
|
tile_height=MAX_TILE_HEIGHT,
|
||||||
):
|
):
|
||||||
tiles_across = (resized_width + tile_width - 1) // tile_width
|
tiles_across = (resized_width + tile_width - 1) // tile_width
|
||||||
tiles_down = (resized_height + tile_height - 1) // tile_height
|
tiles_down = (resized_height + tile_height - 1) // tile_height
|
||||||
|
|
|
@ -21,7 +21,6 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
get_async_httpx_client,
|
get_async_httpx_client,
|
||||||
)
|
)
|
||||||
from litellm.types.llms.anthropic import (
|
from litellm.types.llms.anthropic import (
|
||||||
AnthropicChatCompletionUsageBlock,
|
|
||||||
ContentBlockDelta,
|
ContentBlockDelta,
|
||||||
ContentBlockStart,
|
ContentBlockStart,
|
||||||
ContentBlockStop,
|
ContentBlockStop,
|
||||||
|
@ -32,13 +31,13 @@ from litellm.types.llms.anthropic import (
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
ChatCompletionThinkingBlock,
|
ChatCompletionThinkingBlock,
|
||||||
ChatCompletionToolCallChunk,
|
ChatCompletionToolCallChunk,
|
||||||
ChatCompletionUsageBlock,
|
|
||||||
)
|
)
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
Delta,
|
Delta,
|
||||||
GenericStreamingChunk,
|
GenericStreamingChunk,
|
||||||
ModelResponseStream,
|
ModelResponseStream,
|
||||||
StreamingChoices,
|
StreamingChoices,
|
||||||
|
Usage,
|
||||||
)
|
)
|
||||||
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
|
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
|
||||||
|
|
||||||
|
@ -487,10 +486,8 @@ class ModelResponseIterator:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_usage(
|
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
|
||||||
self, anthropic_usage_chunk: Union[dict, UsageDelta]
|
usage_block = Usage(
|
||||||
) -> AnthropicChatCompletionUsageBlock:
|
|
||||||
usage_block = AnthropicChatCompletionUsageBlock(
|
|
||||||
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
||||||
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
|
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
|
||||||
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
|
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
|
||||||
|
@ -581,7 +578,7 @@ class ModelResponseIterator:
|
||||||
text = ""
|
text = ""
|
||||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||||
finish_reason = ""
|
finish_reason = ""
|
||||||
usage: Optional[ChatCompletionUsageBlock] = None
|
usage: Optional[Usage] = None
|
||||||
provider_specific_fields: Dict[str, Any] = {}
|
provider_specific_fields: Dict[str, Any] = {}
|
||||||
reasoning_content: Optional[str] = None
|
reasoning_content: Optional[str] = None
|
||||||
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
|
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
|
||||||
|
|
|
@ -5,7 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
|
from litellm.constants import (
|
||||||
|
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS,
|
||||||
|
RESPONSE_FORMAT_TOOL_NAME,
|
||||||
|
)
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
|
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
|
||||||
from litellm.llms.base_llm.base_utils import type_to_response_format_param
|
from litellm.llms.base_llm.base_utils import type_to_response_format_param
|
||||||
|
@ -30,9 +33,16 @@ from litellm.types.llms.openai import (
|
||||||
ChatCompletionToolCallFunctionChunk,
|
ChatCompletionToolCallFunctionChunk,
|
||||||
ChatCompletionToolParam,
|
ChatCompletionToolParam,
|
||||||
)
|
)
|
||||||
|
from litellm.types.utils import CompletionTokensDetailsWrapper
|
||||||
from litellm.types.utils import Message as LitellmMessage
|
from litellm.types.utils import Message as LitellmMessage
|
||||||
from litellm.types.utils import PromptTokensDetailsWrapper
|
from litellm.types.utils import PromptTokensDetailsWrapper
|
||||||
from litellm.utils import ModelResponse, Usage, add_dummy_tool, has_tool_call_blocks
|
from litellm.utils import (
|
||||||
|
ModelResponse,
|
||||||
|
Usage,
|
||||||
|
add_dummy_tool,
|
||||||
|
has_tool_call_blocks,
|
||||||
|
token_counter,
|
||||||
|
)
|
||||||
|
|
||||||
from ..common_utils import AnthropicError, process_anthropic_headers
|
from ..common_utils import AnthropicError, process_anthropic_headers
|
||||||
|
|
||||||
|
@ -53,7 +63,7 @@ class AnthropicConfig(BaseConfig):
|
||||||
|
|
||||||
max_tokens: Optional[
|
max_tokens: Optional[
|
||||||
int
|
int
|
||||||
] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
|
] = DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
|
||||||
stop_sequences: Optional[list] = None
|
stop_sequences: Optional[list] = None
|
||||||
temperature: Optional[int] = None
|
temperature: Optional[int] = None
|
||||||
top_p: Optional[int] = None
|
top_p: Optional[int] = None
|
||||||
|
@ -65,7 +75,7 @@ class AnthropicConfig(BaseConfig):
|
||||||
self,
|
self,
|
||||||
max_tokens: Optional[
|
max_tokens: Optional[
|
||||||
int
|
int
|
||||||
] = 4096, # You can pass in a value yourself or use the default value 4096
|
] = DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS, # You can pass in a value yourself or use the default value 4096
|
||||||
stop_sequences: Optional[list] = None,
|
stop_sequences: Optional[list] = None,
|
||||||
temperature: Optional[int] = None,
|
temperature: Optional[int] = None,
|
||||||
top_p: Optional[int] = None,
|
top_p: Optional[int] = None,
|
||||||
|
@ -309,6 +319,33 @@ class AnthropicConfig(BaseConfig):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}")
|
raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}")
|
||||||
|
|
||||||
|
def map_response_format_to_anthropic_tool(
|
||||||
|
self, value: Optional[dict], optional_params: dict, is_thinking_enabled: bool
|
||||||
|
) -> Optional[AnthropicMessagesTool]:
|
||||||
|
ignore_response_format_types = ["text"]
|
||||||
|
if (
|
||||||
|
value is None or value["type"] in ignore_response_format_types
|
||||||
|
): # value is a no-op
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_schema: Optional[dict] = None
|
||||||
|
if "response_schema" in value:
|
||||||
|
json_schema = value["response_schema"]
|
||||||
|
elif "json_schema" in value:
|
||||||
|
json_schema = value["json_schema"]["schema"]
|
||||||
|
"""
|
||||||
|
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||||
|
- You usually want to provide a single tool
|
||||||
|
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||||
|
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_tool = self._create_json_tool_call_for_response_format(
|
||||||
|
json_schema=json_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _tool
|
||||||
|
|
||||||
def map_openai_params(
|
def map_openai_params(
|
||||||
self,
|
self,
|
||||||
non_default_params: dict,
|
non_default_params: dict,
|
||||||
|
@ -352,34 +389,18 @@ class AnthropicConfig(BaseConfig):
|
||||||
if param == "top_p":
|
if param == "top_p":
|
||||||
optional_params["top_p"] = value
|
optional_params["top_p"] = value
|
||||||
if param == "response_format" and isinstance(value, dict):
|
if param == "response_format" and isinstance(value, dict):
|
||||||
ignore_response_format_types = ["text"]
|
_tool = self.map_response_format_to_anthropic_tool(
|
||||||
if value["type"] in ignore_response_format_types: # value is a no-op
|
value, optional_params, is_thinking_enabled
|
||||||
|
)
|
||||||
|
if _tool is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
json_schema: Optional[dict] = None
|
|
||||||
if "response_schema" in value:
|
|
||||||
json_schema = value["response_schema"]
|
|
||||||
elif "json_schema" in value:
|
|
||||||
json_schema = value["json_schema"]["schema"]
|
|
||||||
"""
|
|
||||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
|
||||||
- You usually want to provide a single tool
|
|
||||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
|
||||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not is_thinking_enabled:
|
if not is_thinking_enabled:
|
||||||
_tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
|
_tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
|
||||||
optional_params["tool_choice"] = _tool_choice
|
optional_params["tool_choice"] = _tool_choice
|
||||||
|
optional_params["json_mode"] = True
|
||||||
_tool = self._create_json_tool_call_for_response_format(
|
|
||||||
json_schema=json_schema,
|
|
||||||
)
|
|
||||||
optional_params = self._add_tools_to_optional_params(
|
optional_params = self._add_tools_to_optional_params(
|
||||||
optional_params=optional_params, tools=[_tool]
|
optional_params=optional_params, tools=[_tool]
|
||||||
)
|
)
|
||||||
|
|
||||||
optional_params["json_mode"] = True
|
|
||||||
if param == "user":
|
if param == "user":
|
||||||
optional_params["metadata"] = {"user_id": value}
|
optional_params["metadata"] = {"user_id": value}
|
||||||
if param == "thinking":
|
if param == "thinking":
|
||||||
|
@ -769,6 +790,15 @@ class AnthropicConfig(BaseConfig):
|
||||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||||
cached_tokens=cache_read_input_tokens
|
cached_tokens=cache_read_input_tokens
|
||||||
)
|
)
|
||||||
|
completion_token_details = (
|
||||||
|
CompletionTokensDetailsWrapper(
|
||||||
|
reasoning_tokens=token_counter(
|
||||||
|
text=reasoning_content, count_response_tokens=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if reasoning_content
|
||||||
|
else None
|
||||||
|
)
|
||||||
total_tokens = prompt_tokens + completion_tokens
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
usage = Usage(
|
usage = Usage(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
|
@ -777,6 +807,7 @@ class AnthropicConfig(BaseConfig):
|
||||||
prompt_tokens_details=prompt_tokens_details,
|
prompt_tokens_details=prompt_tokens_details,
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
|
completion_tokens_details=completion_token_details,
|
||||||
)
|
)
|
||||||
|
|
||||||
setattr(model_response, "usage", usage) # type: ignore
|
setattr(model_response, "usage", usage) # type: ignore
|
||||||
|
|
|
@ -11,6 +11,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.constants import DEFAULT_MAX_TOKENS
|
||||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||||
custom_prompt,
|
custom_prompt,
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
|
@ -65,7 +66,9 @@ class AnthropicTextConfig(BaseConfig):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
|
max_tokens_to_sample: Optional[
|
||||||
|
int
|
||||||
|
] = DEFAULT_MAX_TOKENS, # anthropic requires a default
|
||||||
stop_sequences: Optional[list] = None,
|
stop_sequences: Optional[list] = None,
|
||||||
temperature: Optional[int] = None,
|
temperature: Optional[int] = None,
|
||||||
top_p: Optional[int] = None,
|
top_p: Optional[int] = None,
|
||||||
|
|
|
@ -7,7 +7,7 @@ import httpx # type: ignore
|
||||||
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
|
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.constants import DEFAULT_MAX_RETRIES
|
from litellm.constants import AZURE_OPERATION_POLLING_TIMEOUT, DEFAULT_MAX_RETRIES
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
@ -857,7 +857,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
|
|
||||||
await response.aread()
|
await response.aread()
|
||||||
|
|
||||||
timeout_secs: int = 120
|
timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if "status" not in response.json():
|
if "status" not in response.json():
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -955,7 +955,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
|
|
||||||
response.read()
|
response.read()
|
||||||
|
|
||||||
timeout_secs: int = 120
|
timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if "status" not in response.json():
|
if "status" not in response.json():
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -7,6 +7,10 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||||
convert_to_azure_openai_messages,
|
convert_to_azure_openai_messages,
|
||||||
)
|
)
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
from litellm.types.llms.azure import (
|
||||||
|
API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT,
|
||||||
|
API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT,
|
||||||
|
)
|
||||||
from litellm.types.utils import ModelResponse
|
from litellm.types.utils import ModelResponse
|
||||||
from litellm.utils import supports_response_schema
|
from litellm.utils import supports_response_schema
|
||||||
|
|
||||||
|
@ -123,7 +127,10 @@ class AzureOpenAIConfig(BaseConfig):
|
||||||
- check if api_version is supported for response_format
|
- check if api_version is supported for response_format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
is_supported = int(api_version_year) <= 2024 and int(api_version_month) >= 8
|
is_supported = (
|
||||||
|
int(api_version_year) <= API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT
|
||||||
|
and int(api_version_month) >= API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT
|
||||||
|
)
|
||||||
|
|
||||||
return is_supported
|
return is_supported
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ Translations handled by LiteLLM:
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm import verbose_logger
|
from litellm import verbose_logger
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
from litellm.utils import get_model_info
|
from litellm.utils import get_model_info
|
||||||
|
@ -22,6 +23,27 @@ from ...openai.chat.o_series_transformation import OpenAIOSeriesConfig
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIO1Config(OpenAIOSeriesConfig):
|
class AzureOpenAIO1Config(OpenAIOSeriesConfig):
|
||||||
|
def get_supported_openai_params(self, model: str) -> list:
|
||||||
|
"""
|
||||||
|
Get the supported OpenAI params for the Azure O-Series models
|
||||||
|
"""
|
||||||
|
all_openai_params = litellm.OpenAIGPTConfig().get_supported_openai_params(
|
||||||
|
model=model
|
||||||
|
)
|
||||||
|
non_supported_params = [
|
||||||
|
"logprobs",
|
||||||
|
"top_p",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
"top_logprobs",
|
||||||
|
]
|
||||||
|
|
||||||
|
o_series_only_param = ["reasoning_effort"]
|
||||||
|
all_openai_params.extend(o_series_only_param)
|
||||||
|
return [
|
||||||
|
param for param in all_openai_params if param not in non_supported_params
|
||||||
|
]
|
||||||
|
|
||||||
def should_fake_stream(
|
def should_fake_stream(
|
||||||
self,
|
self,
|
||||||
model: Optional[str],
|
model: Optional[str],
|
||||||
|
|
|
@ -28,11 +28,11 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
||||||
self,
|
self,
|
||||||
create_file_data: CreateFileRequest,
|
create_file_data: CreateFileRequest,
|
||||||
openai_client: AsyncAzureOpenAI,
|
openai_client: AsyncAzureOpenAI,
|
||||||
) -> FileObject:
|
) -> OpenAIFileObject:
|
||||||
verbose_logger.debug("create_file_data=%s", create_file_data)
|
verbose_logger.debug("create_file_data=%s", create_file_data)
|
||||||
response = await openai_client.files.create(**create_file_data)
|
response = await openai_client.files.create(**create_file_data)
|
||||||
verbose_logger.debug("create_file_response=%s", response)
|
verbose_logger.debug("create_file_response=%s", response)
|
||||||
return response
|
return OpenAIFileObject(**response.model_dump())
|
||||||
|
|
||||||
def create_file(
|
def create_file(
|
||||||
self,
|
self,
|
||||||
|
@ -66,7 +66,7 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||||
)
|
)
|
||||||
return self.acreate_file( # type: ignore
|
return self.acreate_file(
|
||||||
create_file_data=create_file_data, openai_client=openai_client
|
create_file_data=create_file_data, openai_client=openai_client
|
||||||
)
|
)
|
||||||
response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)
|
response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import json
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
|
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,6 +34,18 @@ class BaseModelResponseIterator:
|
||||||
self, str_line: str
|
self, str_line: str
|
||||||
) -> Union[GenericStreamingChunk, ModelResponseStream]:
|
) -> Union[GenericStreamingChunk, ModelResponseStream]:
|
||||||
# chunk is a str at this point
|
# chunk is a str at this point
|
||||||
|
|
||||||
|
stripped_chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
|
||||||
|
str_line
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if stripped_chunk is not None:
|
||||||
|
stripped_json_chunk: Optional[dict] = json.loads(stripped_chunk)
|
||||||
|
else:
|
||||||
|
stripped_json_chunk = None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
stripped_json_chunk = None
|
||||||
|
|
||||||
if "[DONE]" in str_line:
|
if "[DONE]" in str_line:
|
||||||
return GenericStreamingChunk(
|
return GenericStreamingChunk(
|
||||||
text="",
|
text="",
|
||||||
|
@ -42,9 +55,8 @@ class BaseModelResponseIterator:
|
||||||
index=0,
|
index=0,
|
||||||
tool_use=None,
|
tool_use=None,
|
||||||
)
|
)
|
||||||
elif str_line.startswith("data:"):
|
elif stripped_json_chunk:
|
||||||
data_json = json.loads(str_line[5:])
|
return self.chunk_parser(chunk=stripped_json_chunk)
|
||||||
return self.chunk_parser(chunk=data_json)
|
|
||||||
else:
|
else:
|
||||||
return GenericStreamingChunk(
|
return GenericStreamingChunk(
|
||||||
text="",
|
text="",
|
||||||
|
@ -85,6 +97,7 @@ class BaseModelResponseIterator:
|
||||||
async def __anext__(self):
|
async def __anext__(self):
|
||||||
try:
|
try:
|
||||||
chunk = await self.async_response_iterator.__anext__()
|
chunk = await self.async_response_iterator.__anext__()
|
||||||
|
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
@ -99,7 +112,9 @@ class BaseModelResponseIterator:
|
||||||
str_line = str_line[index:]
|
str_line = str_line[index:]
|
||||||
|
|
||||||
# chunk is a str at this point
|
# chunk is a str at this point
|
||||||
return self._handle_string_chunk(str_line=str_line)
|
chunk = self._handle_string_chunk(str_line=str_line)
|
||||||
|
|
||||||
|
return chunk
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
|
@ -3,6 +3,7 @@ Utility functions for base LLM classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Type, Union
|
from typing import List, Optional, Type, Union
|
||||||
|
|
||||||
|
@ -10,8 +11,8 @@ from openai.lib import _parsing, _pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
|
||||||
from litellm.types.utils import ProviderSpecificModelInfo
|
from litellm.types.utils import Message, ProviderSpecificModelInfo
|
||||||
|
|
||||||
|
|
||||||
class BaseLLMModelInfo(ABC):
|
class BaseLLMModelInfo(ABC):
|
||||||
|
@ -55,6 +56,32 @@ class BaseLLMModelInfo(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tool_response_to_message(
|
||||||
|
tool_calls: List[ChatCompletionToolCallChunk],
|
||||||
|
) -> Optional[Message]:
|
||||||
|
"""
|
||||||
|
In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format
|
||||||
|
|
||||||
|
"""
|
||||||
|
## HANDLE JSON MODE - anthropic returns single function call
|
||||||
|
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
|
||||||
|
try:
|
||||||
|
if json_mode_content_str is not None:
|
||||||
|
args = json.loads(json_mode_content_str)
|
||||||
|
if isinstance(args, dict) and (values := args.get("values")) is not None:
|
||||||
|
_message = Message(content=json.dumps(values))
|
||||||
|
return _message
|
||||||
|
else:
|
||||||
|
# a lot of the times the `values` key is not present in the tool response
|
||||||
|
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
|
||||||
|
_message = Message(content=json.dumps(args))
|
||||||
|
return _message
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# json decode error does occur, return the original tool response str
|
||||||
|
return Message(content=json_mode_content_str)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _dict_to_response_format_helper(
|
def _dict_to_response_format_helper(
|
||||||
response_format: dict, ref_template: Optional[str] = None
|
response_format: dict, ref_template: Optional[str] = None
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -9,7 +9,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL
|
from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL, BEDROCK_MAX_POLICY_SIZE
|
||||||
from litellm.litellm_core_utils.dd_tracing import tracer
|
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
|
|
||||||
|
@ -381,7 +381,7 @@ class BaseAWSLLM:
|
||||||
"region_name": aws_region_name,
|
"region_name": aws_region_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
if sts_response["PackedPolicySize"] > 75:
|
if sts_response["PackedPolicySize"] > BEDROCK_MAX_POLICY_SIZE:
|
||||||
verbose_logger.warning(
|
verbose_logger.warning(
|
||||||
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -368,6 +368,7 @@ class BaseLLMHTTPHandler:
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
return CustomStreamWrapper(
|
return CustomStreamWrapper(
|
||||||
completion_stream=completion_stream,
|
completion_stream=completion_stream,
|
||||||
|
@ -420,6 +421,7 @@ class BaseLLMHTTPHandler:
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
fake_stream: bool = False,
|
fake_stream: bool = False,
|
||||||
client: Optional[HTTPHandler] = None,
|
client: Optional[HTTPHandler] = None,
|
||||||
|
json_mode: bool = False,
|
||||||
) -> Tuple[Any, dict]:
|
) -> Tuple[Any, dict]:
|
||||||
if client is None or not isinstance(client, HTTPHandler):
|
if client is None or not isinstance(client, HTTPHandler):
|
||||||
sync_httpx_client = _get_httpx_client(
|
sync_httpx_client = _get_httpx_client(
|
||||||
|
@ -447,11 +449,15 @@ class BaseLLMHTTPHandler:
|
||||||
|
|
||||||
if fake_stream is True:
|
if fake_stream is True:
|
||||||
completion_stream = provider_config.get_model_response_iterator(
|
completion_stream = provider_config.get_model_response_iterator(
|
||||||
streaming_response=response.json(), sync_stream=True
|
streaming_response=response.json(),
|
||||||
|
sync_stream=True,
|
||||||
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion_stream = provider_config.get_model_response_iterator(
|
completion_stream = provider_config.get_model_response_iterator(
|
||||||
streaming_response=response.iter_lines(), sync_stream=True
|
streaming_response=response.iter_lines(),
|
||||||
|
sync_stream=True,
|
||||||
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
|
|
|
@ -1,84 +0,0 @@
|
||||||
"""
|
|
||||||
Handles the chat completion request for Databricks
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Callable, List, Optional, Union, cast
|
|
||||||
|
|
||||||
from httpx._config import Timeout
|
|
||||||
|
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
|
||||||
from litellm.types.utils import CustomStreamingDecoder
|
|
||||||
from litellm.utils import ModelResponse
|
|
||||||
|
|
||||||
from ...openai_like.chat.handler import OpenAILikeChatHandler
|
|
||||||
from ..common_utils import DatabricksBase
|
|
||||||
from .transformation import DatabricksConfig
|
|
||||||
|
|
||||||
|
|
||||||
class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def completion(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
model: str,
|
|
||||||
messages: list,
|
|
||||||
api_base: str,
|
|
||||||
custom_llm_provider: str,
|
|
||||||
custom_prompt_dict: dict,
|
|
||||||
model_response: ModelResponse,
|
|
||||||
print_verbose: Callable,
|
|
||||||
encoding,
|
|
||||||
api_key: Optional[str],
|
|
||||||
logging_obj,
|
|
||||||
optional_params: dict,
|
|
||||||
acompletion=None,
|
|
||||||
litellm_params=None,
|
|
||||||
logger_fn=None,
|
|
||||||
headers: Optional[dict] = None,
|
|
||||||
timeout: Optional[Union[float, Timeout]] = None,
|
|
||||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
|
||||||
custom_endpoint: Optional[bool] = None,
|
|
||||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
|
||||||
fake_stream: bool = False,
|
|
||||||
):
|
|
||||||
messages = DatabricksConfig()._transform_messages(
|
|
||||||
messages=cast(List[AllMessageValues], messages), model=model
|
|
||||||
)
|
|
||||||
api_base, headers = self.databricks_validate_environment(
|
|
||||||
api_base=api_base,
|
|
||||||
api_key=api_key,
|
|
||||||
endpoint_type="chat_completions",
|
|
||||||
custom_endpoint=custom_endpoint,
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
if optional_params.get("stream") is True:
|
|
||||||
fake_stream = DatabricksConfig()._should_fake_stream(optional_params)
|
|
||||||
else:
|
|
||||||
fake_stream = False
|
|
||||||
|
|
||||||
return super().completion(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
api_base=api_base,
|
|
||||||
custom_llm_provider=custom_llm_provider,
|
|
||||||
custom_prompt_dict=custom_prompt_dict,
|
|
||||||
model_response=model_response,
|
|
||||||
print_verbose=print_verbose,
|
|
||||||
encoding=encoding,
|
|
||||||
api_key=api_key,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
optional_params=optional_params,
|
|
||||||
acompletion=acompletion,
|
|
||||||
litellm_params=litellm_params,
|
|
||||||
logger_fn=logger_fn,
|
|
||||||
headers=headers,
|
|
||||||
timeout=timeout,
|
|
||||||
client=client,
|
|
||||||
custom_endpoint=True,
|
|
||||||
streaming_decoder=streaming_decoder,
|
|
||||||
fake_stream=fake_stream,
|
|
||||||
)
|
|
|
@ -2,21 +2,68 @@
|
||||||
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
|
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import httpx
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||||
|
_handle_invalid_parallel_tool_calls,
|
||||||
|
_should_convert_tool_call_to_json_mode,
|
||||||
|
)
|
||||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||||
handle_messages_with_content_list_to_str_conversion,
|
handle_messages_with_content_list_to_str_conversion,
|
||||||
strip_name_from_messages,
|
strip_name_from_messages,
|
||||||
)
|
)
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||||
from litellm.types.utils import ProviderField
|
from litellm.types.llms.anthropic import AnthropicMessagesTool
|
||||||
|
from litellm.types.llms.databricks import (
|
||||||
|
AllDatabricksContentValues,
|
||||||
|
DatabricksChoice,
|
||||||
|
DatabricksFunction,
|
||||||
|
DatabricksResponse,
|
||||||
|
DatabricksTool,
|
||||||
|
)
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
AllMessageValues,
|
||||||
|
ChatCompletionThinkingBlock,
|
||||||
|
ChatCompletionToolChoiceFunctionParam,
|
||||||
|
ChatCompletionToolChoiceObjectParam,
|
||||||
|
)
|
||||||
|
from litellm.types.utils import (
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
Choices,
|
||||||
|
Message,
|
||||||
|
ModelResponse,
|
||||||
|
ModelResponseStream,
|
||||||
|
ProviderField,
|
||||||
|
Usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...anthropic.chat.transformation import AnthropicConfig
|
||||||
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||||
|
from ..common_utils import DatabricksBase, DatabricksException
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||||
|
|
||||||
|
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||||
|
else:
|
||||||
|
LiteLLMLoggingObj = Any
|
||||||
|
|
||||||
|
|
||||||
class DatabricksConfig(OpenAILikeChatConfig):
|
class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
|
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
|
||||||
"""
|
"""
|
||||||
|
@ -63,6 +110,39 @@ class DatabricksConfig(OpenAILikeChatConfig):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: dict,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
api_base, headers = self.databricks_validate_environment(
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
endpoint_type="chat_completions",
|
||||||
|
custom_endpoint=False,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
# Ensure Content-Type header is set
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def get_complete_url(
|
||||||
|
self,
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_key: Optional[str],
|
||||||
|
model: str,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
) -> str:
|
||||||
|
api_base = self._get_api_base(api_base)
|
||||||
|
complete_url = f"{api_base}/chat/completions"
|
||||||
|
return complete_url
|
||||||
|
|
||||||
def get_supported_openai_params(self, model: Optional[str] = None) -> list:
|
def get_supported_openai_params(self, model: Optional[str] = None) -> list:
|
||||||
return [
|
return [
|
||||||
"stream",
|
"stream",
|
||||||
|
@ -75,8 +155,98 @@ class DatabricksConfig(OpenAILikeChatConfig):
|
||||||
"response_format",
|
"response_format",
|
||||||
"tools",
|
"tools",
|
||||||
"tool_choice",
|
"tool_choice",
|
||||||
|
"reasoning_effort",
|
||||||
|
"thinking",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def convert_anthropic_tool_to_databricks_tool(
|
||||||
|
self, tool: Optional[AnthropicMessagesTool]
|
||||||
|
) -> Optional[DatabricksTool]:
|
||||||
|
if tool is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return DatabricksTool(
|
||||||
|
type="function",
|
||||||
|
function=DatabricksFunction(
|
||||||
|
name=tool["name"],
|
||||||
|
parameters=cast(dict, tool.get("input_schema") or {}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def map_response_format_to_databricks_tool(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
value: Optional[dict],
|
||||||
|
optional_params: dict,
|
||||||
|
is_thinking_enabled: bool,
|
||||||
|
) -> Optional[DatabricksTool]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool = self.map_response_format_to_anthropic_tool(
|
||||||
|
value, optional_params, is_thinking_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
databricks_tool = self.convert_anthropic_tool_to_databricks_tool(tool)
|
||||||
|
return databricks_tool
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
replace_max_completion_tokens_with_max_tokens: bool = True,
|
||||||
|
) -> dict:
|
||||||
|
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
|
||||||
|
mapped_params = super().map_openai_params(
|
||||||
|
non_default_params, optional_params, model, drop_params
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"max_completion_tokens" in non_default_params
|
||||||
|
and replace_max_completion_tokens_with_max_tokens
|
||||||
|
):
|
||||||
|
mapped_params["max_tokens"] = non_default_params[
|
||||||
|
"max_completion_tokens"
|
||||||
|
] # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens'
|
||||||
|
mapped_params.pop("max_completion_tokens", None)
|
||||||
|
|
||||||
|
if "response_format" in non_default_params and "claude" in model:
|
||||||
|
_tool = self.map_response_format_to_databricks_tool(
|
||||||
|
model,
|
||||||
|
non_default_params["response_format"],
|
||||||
|
mapped_params,
|
||||||
|
is_thinking_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _tool is not None:
|
||||||
|
self._add_tools_to_optional_params(
|
||||||
|
optional_params=optional_params, tools=[_tool]
|
||||||
|
)
|
||||||
|
optional_params["json_mode"] = True
|
||||||
|
if not is_thinking_enabled:
|
||||||
|
_tool_choice = ChatCompletionToolChoiceObjectParam(
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolChoiceFunctionParam(
|
||||||
|
name=RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
),
|
||||||
|
)
|
||||||
|
optional_params["tool_choice"] = _tool_choice
|
||||||
|
optional_params.pop(
|
||||||
|
"response_format", None
|
||||||
|
) # unsupported for claude models - if json_schema -> convert to tool call
|
||||||
|
|
||||||
|
if "reasoning_effort" in non_default_params and "claude" in model:
|
||||||
|
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
|
||||||
|
non_default_params.get("reasoning_effort")
|
||||||
|
)
|
||||||
|
## handle thinking tokens
|
||||||
|
self.update_optional_params_with_thinking_tokens(
|
||||||
|
non_default_params=non_default_params, optional_params=mapped_params
|
||||||
|
)
|
||||||
|
|
||||||
|
return mapped_params
|
||||||
|
|
||||||
def _should_fake_stream(self, optional_params: dict) -> bool:
|
def _should_fake_stream(self, optional_params: dict) -> bool:
|
||||||
"""
|
"""
|
||||||
Databricks doesn't support 'response_format' while streaming
|
Databricks doesn't support 'response_format' while streaming
|
||||||
|
@ -104,3 +274,259 @@ class DatabricksConfig(OpenAILikeChatConfig):
|
||||||
new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
|
new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
|
||||||
new_messages = strip_name_from_messages(new_messages)
|
new_messages = strip_name_from_messages(new_messages)
|
||||||
return super()._transform_messages(messages=new_messages, model=model)
|
return super()._transform_messages(messages=new_messages, model=model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_content_str(
|
||||||
|
content: Optional[AllDatabricksContentValues],
|
||||||
|
) -> Optional[str]:
|
||||||
|
if content is None:
|
||||||
|
return None
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
content_str = ""
|
||||||
|
for item in content:
|
||||||
|
if item["type"] == "text":
|
||||||
|
content_str += item["text"]
|
||||||
|
return content_str
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported content type: {type(content)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_reasoning_content(
|
||||||
|
content: Optional[AllDatabricksContentValues],
|
||||||
|
) -> Tuple[Optional[str], Optional[List[ChatCompletionThinkingBlock]]]:
|
||||||
|
"""
|
||||||
|
Extract and return the reasoning content and thinking blocks
|
||||||
|
"""
|
||||||
|
if content is None:
|
||||||
|
return None, None
|
||||||
|
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
|
||||||
|
reasoning_content: Optional[str] = None
|
||||||
|
if isinstance(content, list):
|
||||||
|
for item in content:
|
||||||
|
if item["type"] == "reasoning":
|
||||||
|
for sum in item["summary"]:
|
||||||
|
if reasoning_content is None:
|
||||||
|
reasoning_content = ""
|
||||||
|
reasoning_content += sum["text"]
|
||||||
|
thinking_block = ChatCompletionThinkingBlock(
|
||||||
|
type="thinking",
|
||||||
|
thinking=sum["text"],
|
||||||
|
signature=sum["signature"],
|
||||||
|
)
|
||||||
|
if thinking_blocks is None:
|
||||||
|
thinking_blocks = []
|
||||||
|
thinking_blocks.append(thinking_block)
|
||||||
|
return reasoning_content, thinking_blocks
|
||||||
|
|
||||||
|
def _transform_choices(
|
||||||
|
self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
|
||||||
|
) -> List[Choices]:
|
||||||
|
transformed_choices = []
|
||||||
|
|
||||||
|
for choice in choices:
|
||||||
|
## HANDLE JSON MODE - anthropic returns single function call]
|
||||||
|
tool_calls = choice["message"].get("tool_calls", None)
|
||||||
|
if tool_calls is not None:
|
||||||
|
_openai_tool_calls = []
|
||||||
|
for _tc in tool_calls:
|
||||||
|
_openai_tc = ChatCompletionMessageToolCall(**_tc) # type: ignore
|
||||||
|
_openai_tool_calls.append(_openai_tc)
|
||||||
|
fixed_tool_calls = _handle_invalid_parallel_tool_calls(
|
||||||
|
_openai_tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
if fixed_tool_calls is not None:
|
||||||
|
tool_calls = fixed_tool_calls
|
||||||
|
|
||||||
|
translated_message: Optional[Message] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
if tool_calls and _should_convert_tool_call_to_json_mode(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
convert_tool_call_to_json_mode=json_mode,
|
||||||
|
):
|
||||||
|
# to support response_format on claude models
|
||||||
|
json_mode_content_str: Optional[str] = (
|
||||||
|
str(tool_calls[0]["function"].get("arguments", "")) or None
|
||||||
|
)
|
||||||
|
if json_mode_content_str is not None:
|
||||||
|
translated_message = Message(content=json_mode_content_str)
|
||||||
|
finish_reason = "stop"
|
||||||
|
|
||||||
|
if translated_message is None:
|
||||||
|
## get the content str
|
||||||
|
content_str = DatabricksConfig.extract_content_str(
|
||||||
|
choice["message"]["content"]
|
||||||
|
)
|
||||||
|
|
||||||
|
## get the reasoning content
|
||||||
|
(
|
||||||
|
reasoning_content,
|
||||||
|
thinking_blocks,
|
||||||
|
) = DatabricksConfig.extract_reasoning_content(
|
||||||
|
choice["message"].get("content")
|
||||||
|
)
|
||||||
|
|
||||||
|
translated_message = Message(
|
||||||
|
role="assistant",
|
||||||
|
content=content_str,
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
thinking_blocks=thinking_blocks,
|
||||||
|
tool_calls=choice["message"].get("tool_calls"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if finish_reason is None:
|
||||||
|
finish_reason = choice["finish_reason"]
|
||||||
|
|
||||||
|
translated_choice = Choices(
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
index=choice["index"],
|
||||||
|
message=translated_message,
|
||||||
|
logprobs=None,
|
||||||
|
enhancements=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
transformed_choices.append(translated_choice)
|
||||||
|
|
||||||
|
return transformed_choices
|
||||||
|
|
||||||
|
def transform_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
raw_response: httpx.Response,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
request_data: dict,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
encoding: Any,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
json_mode: Optional[bool] = None,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=raw_response.text,
|
||||||
|
additional_args={"complete_input_dict": request_data},
|
||||||
|
)
|
||||||
|
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = DatabricksResponse(**raw_response.json()) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
response_headers = getattr(raw_response, "headers", None)
|
||||||
|
raise DatabricksException(
|
||||||
|
message="Unable to get json response - {}, Original Response: {}".format(
|
||||||
|
str(e), raw_response.text
|
||||||
|
),
|
||||||
|
status_code=raw_response.status_code,
|
||||||
|
headers=response_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response.model = completion_response["model"]
|
||||||
|
model_response.id = completion_response["id"]
|
||||||
|
model_response.created = completion_response["created"]
|
||||||
|
setattr(model_response, "usage", Usage(**completion_response["usage"]))
|
||||||
|
|
||||||
|
model_response.choices = self._transform_choices( # type: ignore
|
||||||
|
choices=completion_response["choices"],
|
||||||
|
json_mode=json_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def get_model_response_iterator(
|
||||||
|
self,
|
||||||
|
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||||
|
sync_stream: bool,
|
||||||
|
json_mode: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
return DatabricksChatResponseIterator(
|
||||||
|
streaming_response=streaming_response,
|
||||||
|
sync_stream=sync_stream,
|
||||||
|
json_mode=json_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksChatResponseIterator(BaseModelResponseIterator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||||
|
sync_stream: bool,
|
||||||
|
json_mode: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
super().__init__(streaming_response, sync_stream)
|
||||||
|
|
||||||
|
self.json_mode = json_mode
|
||||||
|
self._last_function_name = None # Track the last seen function name
|
||||||
|
|
||||||
|
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||||
|
try:
|
||||||
|
translated_choices = []
|
||||||
|
for choice in chunk["choices"]:
|
||||||
|
tool_calls = choice["delta"].get("tool_calls")
|
||||||
|
if tool_calls and self.json_mode:
|
||||||
|
# 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
# 2. If no function name, just args -> check last function name (saved via state variable)
|
||||||
|
# 3. Convert args to json
|
||||||
|
# 4. Convert json to message
|
||||||
|
# 5. Set content to message.content
|
||||||
|
# 6. Set tool_calls to None
|
||||||
|
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
from litellm.llms.base_llm.base_utils import (
|
||||||
|
_convert_tool_response_to_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if this chunk has a function name
|
||||||
|
function_name = tool_calls[0].get("function", {}).get("name")
|
||||||
|
if function_name is not None:
|
||||||
|
self._last_function_name = function_name
|
||||||
|
|
||||||
|
# If we have a saved function name that matches RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
# or this chunk has the matching function name
|
||||||
|
if (
|
||||||
|
self._last_function_name == RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
or function_name == RESPONSE_FORMAT_TOOL_NAME
|
||||||
|
):
|
||||||
|
# Convert tool calls to message format
|
||||||
|
message = _convert_tool_response_to_message(tool_calls)
|
||||||
|
if message is not None:
|
||||||
|
if message.content == "{}": # empty json
|
||||||
|
message.content = ""
|
||||||
|
choice["delta"]["content"] = message.content
|
||||||
|
choice["delta"]["tool_calls"] = None
|
||||||
|
|
||||||
|
# extract the content str
|
||||||
|
content_str = DatabricksConfig.extract_content_str(
|
||||||
|
choice["delta"].get("content")
|
||||||
|
)
|
||||||
|
|
||||||
|
# extract the reasoning content
|
||||||
|
(
|
||||||
|
reasoning_content,
|
||||||
|
thinking_blocks,
|
||||||
|
) = DatabricksConfig.extract_reasoning_content(
|
||||||
|
choice["delta"]["content"]
|
||||||
|
)
|
||||||
|
|
||||||
|
choice["delta"]["content"] = content_str
|
||||||
|
choice["delta"]["reasoning_content"] = reasoning_content
|
||||||
|
choice["delta"]["thinking_blocks"] = thinking_blocks
|
||||||
|
translated_choices.append(choice)
|
||||||
|
return ModelResponseStream(
|
||||||
|
id=chunk["id"],
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk["created"],
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=translated_choices,
|
||||||
|
)
|
||||||
|
except KeyError as e:
|
||||||
|
raise DatabricksException(
|
||||||
|
message=f"KeyError: {e}, Got unexpected response from Databricks: {chunk}",
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
|
@ -1,9 +1,35 @@
|
||||||
from typing import Literal, Optional, Tuple
|
from typing import Literal, Optional, Tuple
|
||||||
|
|
||||||
from .exceptions import DatabricksError
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksException(BaseLLMException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DatabricksBase:
|
class DatabricksBase:
|
||||||
|
def _get_api_base(self, api_base: Optional[str]) -> str:
|
||||||
|
if api_base is None:
|
||||||
|
try:
|
||||||
|
from databricks.sdk import WorkspaceClient
|
||||||
|
|
||||||
|
databricks_client = WorkspaceClient()
|
||||||
|
|
||||||
|
api_base = (
|
||||||
|
api_base or f"{databricks_client.config.host}/serving-endpoints"
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_base
|
||||||
|
except ImportError:
|
||||||
|
raise DatabricksException(
|
||||||
|
status_code=400,
|
||||||
|
message=(
|
||||||
|
"Either set the DATABRICKS_API_BASE and DATABRICKS_API_KEY environment variables, "
|
||||||
|
"or install the databricks-sdk Python library."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return api_base
|
||||||
|
|
||||||
def _get_databricks_credentials(
|
def _get_databricks_credentials(
|
||||||
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
|
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, dict]:
|
||||||
|
@ -23,7 +49,7 @@ class DatabricksBase:
|
||||||
|
|
||||||
return api_base, headers
|
return api_base, headers
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise DatabricksError(
|
raise DatabricksException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
message=(
|
message=(
|
||||||
"If the Databricks base URL and API key are not set, the databricks-sdk "
|
"If the Databricks base URL and API key are not set, the databricks-sdk "
|
||||||
|
@ -41,9 +67,9 @@ class DatabricksBase:
|
||||||
custom_endpoint: Optional[bool],
|
custom_endpoint: Optional[bool],
|
||||||
headers: Optional[dict],
|
headers: Optional[dict],
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, dict]:
|
||||||
if api_key is None and headers is None:
|
if api_key is None and not headers: # handle empty headers
|
||||||
if custom_endpoint is not None:
|
if custom_endpoint is not None:
|
||||||
raise DatabricksError(
|
raise DatabricksException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||||
)
|
)
|
||||||
|
@ -54,7 +80,7 @@ class DatabricksBase:
|
||||||
|
|
||||||
if api_base is None:
|
if api_base is None:
|
||||||
if custom_endpoint:
|
if custom_endpoint:
|
||||||
raise DatabricksError(
|
raise DatabricksException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
import httpx
|
|
||||||
|
|
||||||
|
|
||||||
class DatabricksError(Exception):
|
|
||||||
def __init__(self, status_code, message):
|
|
||||||
self.status_code = status_code
|
|
||||||
self.message = message
|
|
||||||
self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
|
|
||||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
|
||||||
super().__init__(
|
|
||||||
self.message
|
|
||||||
) # Call the base class constructor with the parameters it needs
|
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.constants import MIN_NON_ZERO_TEMPERATURE
|
||||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
from litellm.secret_managers.main import get_secret_str
|
from litellm.secret_managers.main import get_secret_str
|
||||||
|
|
||||||
|
@ -84,7 +85,7 @@ class DeepInfraConfig(OpenAIGPTConfig):
|
||||||
and value == 0
|
and value == 0
|
||||||
and model == "mistralai/Mistral-7B-Instruct-v0.1"
|
and model == "mistralai/Mistral-7B-Instruct-v0.1"
|
||||||
): # this model does no support temperature == 0
|
): # this model does no support temperature == 0
|
||||||
value = 0.0001 # close to 0
|
value = MIN_NON_ZERO_TEMPERATURE # close to 0
|
||||||
if param == "tool_choice":
|
if param == "tool_choice":
|
||||||
if (
|
if (
|
||||||
value != "auto" and value != "none"
|
value != "auto" and value != "none"
|
||||||
|
|
|
@ -4,6 +4,12 @@ For calculating cost of fireworks ai serverless inference models.
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
from litellm.constants import (
|
||||||
|
FIREWORKS_AI_16_B,
|
||||||
|
FIREWORKS_AI_56_B_MOE,
|
||||||
|
FIREWORKS_AI_80_B,
|
||||||
|
FIREWORKS_AI_176_B_MOE,
|
||||||
|
)
|
||||||
from litellm.types.utils import Usage
|
from litellm.types.utils import Usage
|
||||||
from litellm.utils import get_model_info
|
from litellm.utils import get_model_info
|
||||||
|
|
||||||
|
@ -25,9 +31,9 @@ def get_base_model_for_pricing(model_name: str) -> str:
|
||||||
moe_match = re.search(r"(\d+)x(\d+)b", model_name)
|
moe_match = re.search(r"(\d+)x(\d+)b", model_name)
|
||||||
if moe_match:
|
if moe_match:
|
||||||
total_billion = int(moe_match.group(1)) * int(moe_match.group(2))
|
total_billion = int(moe_match.group(1)) * int(moe_match.group(2))
|
||||||
if total_billion <= 56:
|
if total_billion <= FIREWORKS_AI_56_B_MOE:
|
||||||
return "fireworks-ai-moe-up-to-56b"
|
return "fireworks-ai-moe-up-to-56b"
|
||||||
elif total_billion <= 176:
|
elif total_billion <= FIREWORKS_AI_176_B_MOE:
|
||||||
return "fireworks-ai-56b-to-176b"
|
return "fireworks-ai-56b-to-176b"
|
||||||
|
|
||||||
# Check for standard models in the form <number>b
|
# Check for standard models in the form <number>b
|
||||||
|
@ -37,9 +43,9 @@ def get_base_model_for_pricing(model_name: str) -> str:
|
||||||
params_billion = float(params_match)
|
params_billion = float(params_match)
|
||||||
|
|
||||||
# Determine the category based on the number of parameters
|
# Determine the category based on the number of parameters
|
||||||
if params_billion <= 16.0:
|
if params_billion <= FIREWORKS_AI_16_B:
|
||||||
return "fireworks-ai-up-to-16b"
|
return "fireworks-ai-up-to-16b"
|
||||||
elif params_billion <= 80.0:
|
elif params_billion <= FIREWORKS_AI_80_B:
|
||||||
return "fireworks-ai-16b-80b"
|
return "fireworks-ai-16b-80b"
|
||||||
|
|
||||||
# If no matches, return the original model_name
|
# If no matches, return the original model_name
|
||||||
|
|
|
@ -81,6 +81,7 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
|
||||||
"stop",
|
"stop",
|
||||||
"logprobs",
|
"logprobs",
|
||||||
"frequency_penalty",
|
"frequency_penalty",
|
||||||
|
"modalities",
|
||||||
]
|
]
|
||||||
|
|
||||||
def map_openai_params(
|
def map_openai_params(
|
||||||
|
|
|
@ -1,769 +0,0 @@
|
||||||
## Uses the huggingface text generation inference API
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
get_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
import litellm
|
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
||||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
|
||||||
AsyncHTTPHandler,
|
|
||||||
HTTPHandler,
|
|
||||||
_get_httpx_client,
|
|
||||||
get_async_httpx_client,
|
|
||||||
)
|
|
||||||
from litellm.llms.huggingface.chat.transformation import (
|
|
||||||
HuggingfaceChatConfig as HuggingfaceConfig,
|
|
||||||
)
|
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
|
||||||
from litellm.types.utils import EmbeddingResponse
|
|
||||||
from litellm.types.utils import Logprobs as TextCompletionLogprobs
|
|
||||||
from litellm.types.utils import ModelResponse
|
|
||||||
|
|
||||||
from ...base import BaseLLM
|
|
||||||
from ..common_utils import HuggingfaceError
|
|
||||||
|
|
||||||
hf_chat_config = HuggingfaceConfig()
|
|
||||||
|
|
||||||
|
|
||||||
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
|
|
||||||
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_hf_task_embedding_for_model(
|
|
||||||
model: str, task_type: Optional[str], api_base: str
|
|
||||||
) -> Optional[str]:
|
|
||||||
if task_type is not None:
|
|
||||||
if task_type in get_args(hf_tasks_embeddings):
|
|
||||||
return task_type
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"Invalid task_type={}. Expected one of={}".format(
|
|
||||||
task_type, hf_tasks_embeddings
|
|
||||||
)
|
|
||||||
)
|
|
||||||
http_client = HTTPHandler(concurrent_limit=1)
|
|
||||||
|
|
||||||
model_info = http_client.get(url=api_base)
|
|
||||||
|
|
||||||
model_info_dict = model_info.json()
|
|
||||||
|
|
||||||
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
|
|
||||||
|
|
||||||
return pipeline_tag
|
|
||||||
|
|
||||||
|
|
||||||
async def async_get_hf_task_embedding_for_model(
|
|
||||||
model: str, task_type: Optional[str], api_base: str
|
|
||||||
) -> Optional[str]:
|
|
||||||
if task_type is not None:
|
|
||||||
if task_type in get_args(hf_tasks_embeddings):
|
|
||||||
return task_type
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"Invalid task_type={}. Expected one of={}".format(
|
|
||||||
task_type, hf_tasks_embeddings
|
|
||||||
)
|
|
||||||
)
|
|
||||||
http_client = get_async_httpx_client(
|
|
||||||
llm_provider=litellm.LlmProviders.HUGGINGFACE,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_info = await http_client.get(url=api_base)
|
|
||||||
|
|
||||||
model_info_dict = model_info.json()
|
|
||||||
|
|
||||||
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
|
|
||||||
|
|
||||||
return pipeline_tag
|
|
||||||
|
|
||||||
|
|
||||||
async def make_call(
|
|
||||||
client: Optional[AsyncHTTPHandler],
|
|
||||||
api_base: str,
|
|
||||||
headers: dict,
|
|
||||||
data: str,
|
|
||||||
model: str,
|
|
||||||
messages: list,
|
|
||||||
logging_obj,
|
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
|
||||||
json_mode: bool,
|
|
||||||
) -> Tuple[Any, httpx.Headers]:
|
|
||||||
if client is None:
|
|
||||||
client = litellm.module_level_aclient
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.post(
|
|
||||||
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
|
||||||
)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_headers = getattr(e, "headers", None)
|
|
||||||
error_response = getattr(e, "response", None)
|
|
||||||
if error_headers is None and error_response:
|
|
||||||
error_headers = getattr(error_response, "headers", None)
|
|
||||||
raise HuggingfaceError(
|
|
||||||
status_code=e.response.status_code,
|
|
||||||
message=str(await e.response.aread()),
|
|
||||||
headers=cast(dict, error_headers) if error_headers else None,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
|
||||||
if isinstance(e, exception):
|
|
||||||
raise e
|
|
||||||
raise HuggingfaceError(status_code=500, message=str(e))
|
|
||||||
|
|
||||||
# LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=messages,
|
|
||||||
api_key="",
|
|
||||||
original_response=response, # Pass the completion stream for logging
|
|
||||||
additional_args={"complete_input_dict": data},
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.aiter_lines(), response.headers
|
|
||||||
|
|
||||||
|
|
||||||
class Huggingface(BaseLLM):
|
|
||||||
_client_session: Optional[httpx.Client] = None
|
|
||||||
_aclient_session: Optional[httpx.AsyncClient] = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def completion( # noqa: PLR0915
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: list,
|
|
||||||
api_base: Optional[str],
|
|
||||||
model_response: ModelResponse,
|
|
||||||
print_verbose: Callable,
|
|
||||||
timeout: float,
|
|
||||||
encoding,
|
|
||||||
api_key,
|
|
||||||
logging_obj,
|
|
||||||
optional_params: dict,
|
|
||||||
litellm_params: dict,
|
|
||||||
custom_prompt_dict={},
|
|
||||||
acompletion: bool = False,
|
|
||||||
logger_fn=None,
|
|
||||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
|
||||||
headers: dict = {},
|
|
||||||
):
|
|
||||||
super().completion()
|
|
||||||
exception_mapping_worked = False
|
|
||||||
try:
|
|
||||||
task, model = hf_chat_config.get_hf_task_for_model(model)
|
|
||||||
litellm_params["task"] = task
|
|
||||||
headers = hf_chat_config.validate_environment(
|
|
||||||
api_key=api_key,
|
|
||||||
headers=headers,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
optional_params=optional_params,
|
|
||||||
)
|
|
||||||
completion_url = hf_chat_config.get_api_base(api_base=api_base, model=model)
|
|
||||||
data = hf_chat_config.transform_request(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
optional_params=optional_params,
|
|
||||||
litellm_params=litellm_params,
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.pre_call(
|
|
||||||
input=data,
|
|
||||||
api_key=api_key,
|
|
||||||
additional_args={
|
|
||||||
"complete_input_dict": data,
|
|
||||||
"headers": headers,
|
|
||||||
"api_base": completion_url,
|
|
||||||
"acompletion": acompletion,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
## COMPLETION CALL
|
|
||||||
|
|
||||||
if acompletion is True:
|
|
||||||
### ASYNC STREAMING
|
|
||||||
if optional_params.get("stream", False):
|
|
||||||
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages) # type: ignore
|
|
||||||
else:
|
|
||||||
### ASYNC COMPLETION
|
|
||||||
return self.acompletion(
|
|
||||||
api_base=completion_url,
|
|
||||||
data=data,
|
|
||||||
headers=headers,
|
|
||||||
model_response=model_response,
|
|
||||||
encoding=encoding,
|
|
||||||
model=model,
|
|
||||||
optional_params=optional_params,
|
|
||||||
timeout=timeout,
|
|
||||||
litellm_params=litellm_params,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
api_key=api_key,
|
|
||||||
messages=messages,
|
|
||||||
client=(
|
|
||||||
client
|
|
||||||
if client is not None
|
|
||||||
and isinstance(client, AsyncHTTPHandler)
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if client is None or not isinstance(client, HTTPHandler):
|
|
||||||
client = _get_httpx_client()
|
|
||||||
### SYNC STREAMING
|
|
||||||
if "stream" in optional_params and optional_params["stream"] is True:
|
|
||||||
response = client.post(
|
|
||||||
url=completion_url,
|
|
||||||
headers=headers,
|
|
||||||
data=json.dumps(data),
|
|
||||||
stream=optional_params["stream"],
|
|
||||||
)
|
|
||||||
return response.iter_lines()
|
|
||||||
### SYNC COMPLETION
|
|
||||||
else:
|
|
||||||
response = client.post(
|
|
||||||
url=completion_url,
|
|
||||||
headers=headers,
|
|
||||||
data=json.dumps(data),
|
|
||||||
)
|
|
||||||
|
|
||||||
return hf_chat_config.transform_response(
|
|
||||||
model=model,
|
|
||||||
raw_response=response,
|
|
||||||
model_response=model_response,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
api_key=api_key,
|
|
||||||
request_data=data,
|
|
||||||
messages=messages,
|
|
||||||
optional_params=optional_params,
|
|
||||||
encoding=encoding,
|
|
||||||
json_mode=None,
|
|
||||||
litellm_params=litellm_params,
|
|
||||||
)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
raise HuggingfaceError(
|
|
||||||
status_code=e.response.status_code,
|
|
||||||
message=e.response.text,
|
|
||||||
headers=e.response.headers,
|
|
||||||
)
|
|
||||||
except HuggingfaceError as e:
|
|
||||||
exception_mapping_worked = True
|
|
||||||
raise e
|
|
||||||
except Exception as e:
|
|
||||||
if exception_mapping_worked:
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
|
|
||||||
|
|
||||||
async def acompletion(
|
|
||||||
self,
|
|
||||||
api_base: str,
|
|
||||||
data: dict,
|
|
||||||
headers: dict,
|
|
||||||
model_response: ModelResponse,
|
|
||||||
encoding: Any,
|
|
||||||
model: str,
|
|
||||||
optional_params: dict,
|
|
||||||
litellm_params: dict,
|
|
||||||
timeout: float,
|
|
||||||
logging_obj: LiteLLMLoggingObj,
|
|
||||||
api_key: str,
|
|
||||||
messages: List[AllMessageValues],
|
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
|
||||||
):
|
|
||||||
response: Optional[httpx.Response] = None
|
|
||||||
try:
|
|
||||||
if client is None:
|
|
||||||
client = get_async_httpx_client(
|
|
||||||
llm_provider=litellm.LlmProviders.HUGGINGFACE
|
|
||||||
)
|
|
||||||
### ASYNC COMPLETION
|
|
||||||
http_response = await client.post(
|
|
||||||
url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
response = http_response
|
|
||||||
|
|
||||||
return hf_chat_config.transform_response(
|
|
||||||
model=model,
|
|
||||||
raw_response=http_response,
|
|
||||||
model_response=model_response,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
api_key=api_key,
|
|
||||||
request_data=data,
|
|
||||||
messages=messages,
|
|
||||||
optional_params=optional_params,
|
|
||||||
encoding=encoding,
|
|
||||||
json_mode=None,
|
|
||||||
litellm_params=litellm_params,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, httpx.TimeoutException):
|
|
||||||
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
|
|
||||||
elif isinstance(e, HuggingfaceError):
|
|
||||||
raise e
|
|
||||||
elif response is not None and hasattr(response, "text"):
|
|
||||||
raise HuggingfaceError(
|
|
||||||
status_code=500,
|
|
||||||
message=f"{str(e)}\n\nOriginal Response: {response.text}",
|
|
||||||
headers=response.headers,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
|
|
||||||
|
|
||||||
async def async_streaming(
|
|
||||||
self,
|
|
||||||
logging_obj,
|
|
||||||
api_base: str,
|
|
||||||
data: dict,
|
|
||||||
headers: dict,
|
|
||||||
model_response: ModelResponse,
|
|
||||||
messages: List[AllMessageValues],
|
|
||||||
model: str,
|
|
||||||
timeout: float,
|
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
|
||||||
):
|
|
||||||
completion_stream, _ = await make_call(
|
|
||||||
client=client,
|
|
||||||
api_base=api_base,
|
|
||||||
headers=headers,
|
|
||||||
data=json.dumps(data),
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
timeout=timeout,
|
|
||||||
json_mode=False,
|
|
||||||
)
|
|
||||||
streamwrapper = CustomStreamWrapper(
|
|
||||||
completion_stream=completion_stream,
|
|
||||||
model=model,
|
|
||||||
custom_llm_provider="huggingface",
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
)
|
|
||||||
return streamwrapper
|
|
||||||
|
|
||||||
def _transform_input_on_pipeline_tag(
|
|
||||||
self, input: List, pipeline_tag: Optional[str]
|
|
||||||
) -> dict:
|
|
||||||
if pipeline_tag is None:
|
|
||||||
return {"inputs": input}
|
|
||||||
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
|
|
||||||
if len(input) < 2:
|
|
||||||
raise HuggingfaceError(
|
|
||||||
status_code=400,
|
|
||||||
message="sentence-similarity requires 2+ sentences",
|
|
||||||
)
|
|
||||||
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
|
||||||
elif pipeline_tag == "rerank":
|
|
||||||
if len(input) < 2:
|
|
||||||
raise HuggingfaceError(
|
|
||||||
status_code=400,
|
|
||||||
message="reranker requires 2+ sentences",
|
|
||||||
)
|
|
||||||
return {"inputs": {"query": input[0], "texts": input[1:]}}
|
|
||||||
return {"inputs": input} # default to feature-extraction pipeline tag
|
|
||||||
|
|
||||||
async def _async_transform_input(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
task_type: Optional[str],
|
|
||||||
embed_url: str,
|
|
||||||
input: List,
|
|
||||||
optional_params: dict,
|
|
||||||
) -> dict:
|
|
||||||
hf_task = await async_get_hf_task_embedding_for_model(
|
|
||||||
model=model, task_type=task_type, api_base=embed_url
|
|
||||||
)
|
|
||||||
|
|
||||||
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
|
||||||
|
|
||||||
if len(optional_params.keys()) > 0:
|
|
||||||
data["options"] = optional_params
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
|
|
||||||
special_options_keys = HuggingfaceConfig().get_special_options_params()
|
|
||||||
special_parameters_keys = [
|
|
||||||
"min_length",
|
|
||||||
"max_length",
|
|
||||||
"top_k",
|
|
||||||
"top_p",
|
|
||||||
"temperature",
|
|
||||||
"repetition_penalty",
|
|
||||||
"max_time",
|
|
||||||
]
|
|
||||||
|
|
||||||
for k, v in optional_params.items():
|
|
||||||
if k in special_options_keys:
|
|
||||||
data.setdefault("options", {})
|
|
||||||
data["options"][k] = v
|
|
||||||
elif k in special_parameters_keys:
|
|
||||||
data.setdefault("parameters", {})
|
|
||||||
data["parameters"][k] = v
|
|
||||||
else:
|
|
||||||
data[k] = v
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _transform_input(
|
|
||||||
self,
|
|
||||||
input: List,
|
|
||||||
model: str,
|
|
||||||
call_type: Literal["sync", "async"],
|
|
||||||
optional_params: dict,
|
|
||||||
embed_url: str,
|
|
||||||
) -> dict:
|
|
||||||
data: Dict = {}
|
|
||||||
|
|
||||||
## TRANSFORMATION ##
|
|
||||||
if "sentence-transformers" in model:
|
|
||||||
if len(input) == 0:
|
|
||||||
raise HuggingfaceError(
|
|
||||||
status_code=400,
|
|
||||||
message="sentence transformers requires 2+ sentences",
|
|
||||||
)
|
|
||||||
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
|
||||||
else:
|
|
||||||
data = {"inputs": input}
|
|
||||||
|
|
||||||
task_type = optional_params.pop("input_type", None)
|
|
||||||
|
|
||||||
if call_type == "sync":
|
|
||||||
hf_task = get_hf_task_embedding_for_model(
|
|
||||||
model=model, task_type=task_type, api_base=embed_url
|
|
||||||
)
|
|
||||||
elif call_type == "async":
|
|
||||||
return self._async_transform_input(
|
|
||||||
model=model, task_type=task_type, embed_url=embed_url, input=input
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
data = self._transform_input_on_pipeline_tag(
|
|
||||||
input=input, pipeline_tag=hf_task
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(optional_params.keys()) > 0:
|
|
||||||
data = self._process_optional_params(
|
|
||||||
data=data, optional_params=optional_params
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _process_embedding_response(
|
|
||||||
self,
|
|
||||||
embeddings: dict,
|
|
||||||
model_response: EmbeddingResponse,
|
|
||||||
model: str,
|
|
||||||
input: List,
|
|
||||||
encoding: Any,
|
|
||||||
) -> EmbeddingResponse:
|
|
||||||
output_data = []
|
|
||||||
if "similarities" in embeddings:
|
|
||||||
for idx, embedding in embeddings["similarities"]:
|
|
||||||
output_data.append(
|
|
||||||
{
|
|
||||||
"object": "embedding",
|
|
||||||
"index": idx,
|
|
||||||
"embedding": embedding, # flatten list returned from hf
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for idx, embedding in enumerate(embeddings):
|
|
||||||
if isinstance(embedding, float):
|
|
||||||
output_data.append(
|
|
||||||
{
|
|
||||||
"object": "embedding",
|
|
||||||
"index": idx,
|
|
||||||
"embedding": embedding, # flatten list returned from hf
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif isinstance(embedding, list) and isinstance(embedding[0], float):
|
|
||||||
output_data.append(
|
|
||||||
{
|
|
||||||
"object": "embedding",
|
|
||||||
"index": idx,
|
|
||||||
"embedding": embedding, # flatten list returned from hf
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output_data.append(
|
|
||||||
{
|
|
||||||
"object": "embedding",
|
|
||||||
"index": idx,
|
|
||||||
"embedding": embedding[0][
|
|
||||||
0
|
|
||||||
], # flatten list returned from hf
|
|
||||||
}
|
|
||||||
)
|
|
||||||
model_response.object = "list"
|
|
||||||
model_response.data = output_data
|
|
||||||
model_response.model = model
|
|
||||||
input_tokens = 0
|
|
||||||
for text in input:
|
|
||||||
input_tokens += len(encoding.encode(text))
|
|
||||||
|
|
||||||
setattr(
|
|
||||||
model_response,
|
|
||||||
"usage",
|
|
||||||
litellm.Usage(
|
|
||||||
prompt_tokens=input_tokens,
|
|
||||||
completion_tokens=input_tokens,
|
|
||||||
total_tokens=input_tokens,
|
|
||||||
prompt_tokens_details=None,
|
|
||||||
completion_tokens_details=None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return model_response
|
|
||||||
|
|
||||||
async def aembedding(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
input: list,
|
|
||||||
model_response: litellm.utils.EmbeddingResponse,
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
logging_obj: LiteLLMLoggingObj,
|
|
||||||
optional_params: dict,
|
|
||||||
api_base: str,
|
|
||||||
api_key: Optional[str],
|
|
||||||
headers: dict,
|
|
||||||
encoding: Callable,
|
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
|
||||||
):
|
|
||||||
## TRANSFORMATION ##
|
|
||||||
data = self._transform_input(
|
|
||||||
input=input,
|
|
||||||
model=model,
|
|
||||||
call_type="sync",
|
|
||||||
optional_params=optional_params,
|
|
||||||
embed_url=api_base,
|
|
||||||
)
|
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.pre_call(
|
|
||||||
input=input,
|
|
||||||
api_key=api_key,
|
|
||||||
additional_args={
|
|
||||||
"complete_input_dict": data,
|
|
||||||
"headers": headers,
|
|
||||||
"api_base": api_base,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
## COMPLETION CALL
|
|
||||||
if client is None:
|
|
||||||
client = get_async_httpx_client(
|
|
||||||
llm_provider=litellm.LlmProviders.HUGGINGFACE,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=input,
|
|
||||||
api_key=api_key,
|
|
||||||
additional_args={"complete_input_dict": data},
|
|
||||||
original_response=response,
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = response.json()
|
|
||||||
|
|
||||||
if "error" in embeddings:
|
|
||||||
raise HuggingfaceError(status_code=500, message=embeddings["error"])
|
|
||||||
|
|
||||||
## PROCESS RESPONSE ##
|
|
||||||
return self._process_embedding_response(
|
|
||||||
embeddings=embeddings,
|
|
||||||
model_response=model_response,
|
|
||||||
model=model,
|
|
||||||
input=input,
|
|
||||||
encoding=encoding,
|
|
||||||
)
|
|
||||||
|
|
||||||
def embedding(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
input: list,
|
|
||||||
model_response: EmbeddingResponse,
|
|
||||||
optional_params: dict,
|
|
||||||
logging_obj: LiteLLMLoggingObj,
|
|
||||||
encoding: Callable,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
api_base: Optional[str] = None,
|
|
||||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
|
||||||
aembedding: Optional[bool] = None,
|
|
||||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
|
||||||
headers={},
|
|
||||||
) -> EmbeddingResponse:
|
|
||||||
super().embedding()
|
|
||||||
headers = hf_chat_config.validate_environment(
|
|
||||||
api_key=api_key,
|
|
||||||
headers=headers,
|
|
||||||
model=model,
|
|
||||||
optional_params=optional_params,
|
|
||||||
messages=[],
|
|
||||||
)
|
|
||||||
# print_verbose(f"{model}, {task}")
|
|
||||||
embed_url = ""
|
|
||||||
if "https" in model:
|
|
||||||
embed_url = model
|
|
||||||
elif api_base:
|
|
||||||
embed_url = api_base
|
|
||||||
elif "HF_API_BASE" in os.environ:
|
|
||||||
embed_url = os.getenv("HF_API_BASE", "")
|
|
||||||
elif "HUGGINGFACE_API_BASE" in os.environ:
|
|
||||||
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
|
||||||
else:
|
|
||||||
embed_url = f"https://api-inference.huggingface.co/models/{model}"
|
|
||||||
|
|
||||||
## ROUTING ##
|
|
||||||
if aembedding is True:
|
|
||||||
return self.aembedding(
|
|
||||||
input=input,
|
|
||||||
model_response=model_response,
|
|
||||||
timeout=timeout,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
headers=headers,
|
|
||||||
api_base=embed_url, # type: ignore
|
|
||||||
api_key=api_key,
|
|
||||||
client=client if isinstance(client, AsyncHTTPHandler) else None,
|
|
||||||
model=model,
|
|
||||||
optional_params=optional_params,
|
|
||||||
encoding=encoding,
|
|
||||||
)
|
|
||||||
|
|
||||||
## TRANSFORMATION ##
|
|
||||||
|
|
||||||
data = self._transform_input(
|
|
||||||
input=input,
|
|
||||||
model=model,
|
|
||||||
call_type="sync",
|
|
||||||
optional_params=optional_params,
|
|
||||||
embed_url=embed_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.pre_call(
|
|
||||||
input=input,
|
|
||||||
api_key=api_key,
|
|
||||||
additional_args={
|
|
||||||
"complete_input_dict": data,
|
|
||||||
"headers": headers,
|
|
||||||
"api_base": embed_url,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
## COMPLETION CALL
|
|
||||||
if client is None or not isinstance(client, HTTPHandler):
|
|
||||||
client = HTTPHandler(concurrent_limit=1)
|
|
||||||
response = client.post(embed_url, headers=headers, data=json.dumps(data))
|
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=input,
|
|
||||||
api_key=api_key,
|
|
||||||
additional_args={"complete_input_dict": data},
|
|
||||||
original_response=response,
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = response.json()
|
|
||||||
|
|
||||||
if "error" in embeddings:
|
|
||||||
raise HuggingfaceError(status_code=500, message=embeddings["error"])
|
|
||||||
|
|
||||||
## PROCESS RESPONSE ##
|
|
||||||
return self._process_embedding_response(
|
|
||||||
embeddings=embeddings,
|
|
||||||
model_response=model_response,
|
|
||||||
model=model,
|
|
||||||
input=input,
|
|
||||||
encoding=encoding,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _transform_logprobs(
|
|
||||||
self, hf_response: Optional[List]
|
|
||||||
) -> Optional[TextCompletionLogprobs]:
|
|
||||||
"""
|
|
||||||
Transform Hugging Face logprobs to OpenAI.Completion() format
|
|
||||||
"""
|
|
||||||
if hf_response is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Initialize an empty list for the transformed logprobs
|
|
||||||
_logprob: TextCompletionLogprobs = TextCompletionLogprobs(
|
|
||||||
text_offset=[],
|
|
||||||
token_logprobs=[],
|
|
||||||
tokens=[],
|
|
||||||
top_logprobs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
# For each Hugging Face response, transform the logprobs
|
|
||||||
for response in hf_response:
|
|
||||||
# Extract the relevant information from the response
|
|
||||||
response_details = response["details"]
|
|
||||||
top_tokens = response_details.get("top_tokens", {})
|
|
||||||
|
|
||||||
for i, token in enumerate(response_details["prefill"]):
|
|
||||||
# Extract the text of the token
|
|
||||||
token_text = token["text"]
|
|
||||||
|
|
||||||
# Extract the logprob of the token
|
|
||||||
token_logprob = token["logprob"]
|
|
||||||
|
|
||||||
# Add the token information to the 'token_info' list
|
|
||||||
cast(List[str], _logprob.tokens).append(token_text)
|
|
||||||
cast(List[float], _logprob.token_logprobs).append(token_logprob)
|
|
||||||
|
|
||||||
# stub this to work with llm eval harness
|
|
||||||
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
|
|
||||||
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
|
|
||||||
top_alt_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# For each element in the 'tokens' list, extract the relevant information
|
|
||||||
for i, token in enumerate(response_details["tokens"]):
|
|
||||||
# Extract the text of the token
|
|
||||||
token_text = token["text"]
|
|
||||||
|
|
||||||
# Extract the logprob of the token
|
|
||||||
token_logprob = token["logprob"]
|
|
||||||
|
|
||||||
top_alt_tokens = {}
|
|
||||||
temp_top_logprobs = []
|
|
||||||
if top_tokens != {}:
|
|
||||||
temp_top_logprobs = top_tokens[i]
|
|
||||||
|
|
||||||
# top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
|
|
||||||
for elem in temp_top_logprobs:
|
|
||||||
text = elem["text"]
|
|
||||||
logprob = elem["logprob"]
|
|
||||||
top_alt_tokens[text] = logprob
|
|
||||||
|
|
||||||
# Add the token information to the 'token_info' list
|
|
||||||
cast(List[str], _logprob.tokens).append(token_text)
|
|
||||||
cast(List[float], _logprob.token_logprobs).append(token_logprob)
|
|
||||||
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
|
|
||||||
top_alt_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the text offset of the token
|
|
||||||
# This is computed as the sum of the lengths of all previous tokens
|
|
||||||
cast(List[int], _logprob.text_offset).append(
|
|
||||||
sum(len(t["text"]) for t in response_details["tokens"][:i])
|
|
||||||
)
|
|
||||||
|
|
||||||
return _logprob
|
|
|
@ -1,27 +1,10 @@
|
||||||
import json
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||||
from copy import deepcopy
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
|
||||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
|
||||||
convert_content_list_to_str,
|
|
||||||
)
|
|
||||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
|
||||||
custom_prompt,
|
|
||||||
prompt_factory,
|
|
||||||
)
|
|
||||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
|
||||||
from litellm.secret_managers.main import get_secret_str
|
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
|
||||||
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
|
||||||
from litellm.utils import token_counter
|
|
||||||
|
|
||||||
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks, output_parser
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
@ -30,176 +13,98 @@ if TYPE_CHECKING:
|
||||||
else:
|
else:
|
||||||
LoggingClass = Any
|
LoggingClass = Any
|
||||||
|
|
||||||
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
|
||||||
tgi_models_cache = None
|
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
conv_models_cache = None
|
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceChatConfig(BaseConfig):
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BASE_URL = "https://router.huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceChatConfig(OpenAIGPTConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hf_task: Optional[
|
def validate_environment(
|
||||||
hf_tasks
|
|
||||||
] = None # litellm-specific param, used to know the api spec to use when calling huggingface api
|
|
||||||
best_of: Optional[int] = None
|
|
||||||
decoder_input_details: Optional[bool] = None
|
|
||||||
details: Optional[bool] = True # enables returning logprobs + best of
|
|
||||||
max_new_tokens: Optional[int] = None
|
|
||||||
repetition_penalty: Optional[float] = None
|
|
||||||
return_full_text: Optional[
|
|
||||||
bool
|
|
||||||
] = False # by default don't return the input as part of the output
|
|
||||||
seed: Optional[int] = None
|
|
||||||
temperature: Optional[float] = None
|
|
||||||
top_k: Optional[int] = None
|
|
||||||
top_n_tokens: Optional[int] = None
|
|
||||||
top_p: Optional[int] = None
|
|
||||||
truncate: Optional[int] = None
|
|
||||||
typical_p: Optional[float] = None
|
|
||||||
watermark: Optional[bool] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
self,
|
||||||
best_of: Optional[int] = None,
|
headers: dict,
|
||||||
decoder_input_details: Optional[bool] = None,
|
|
||||||
details: Optional[bool] = None,
|
|
||||||
max_new_tokens: Optional[int] = None,
|
|
||||||
repetition_penalty: Optional[float] = None,
|
|
||||||
return_full_text: Optional[bool] = None,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
top_n_tokens: Optional[int] = None,
|
|
||||||
top_p: Optional[int] = None,
|
|
||||||
truncate: Optional[int] = None,
|
|
||||||
typical_p: Optional[float] = None,
|
|
||||||
watermark: Optional[bool] = None,
|
|
||||||
) -> None:
|
|
||||||
locals_ = locals().copy()
|
|
||||||
for key, value in locals_.items():
|
|
||||||
if key != "self" and value is not None:
|
|
||||||
setattr(self.__class__, key, value)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_config(cls):
|
|
||||||
return super().get_config()
|
|
||||||
|
|
||||||
def get_special_options_params(self):
|
|
||||||
return ["use_cache", "wait_for_model"]
|
|
||||||
|
|
||||||
def get_supported_openai_params(self, model: str):
|
|
||||||
return [
|
|
||||||
"stream",
|
|
||||||
"temperature",
|
|
||||||
"max_tokens",
|
|
||||||
"max_completion_tokens",
|
|
||||||
"top_p",
|
|
||||||
"stop",
|
|
||||||
"n",
|
|
||||||
"echo",
|
|
||||||
]
|
|
||||||
|
|
||||||
def map_openai_params(
|
|
||||||
self,
|
|
||||||
non_default_params: Dict,
|
|
||||||
optional_params: Dict,
|
|
||||||
model: str,
|
model: str,
|
||||||
drop_params: bool,
|
messages: List[AllMessageValues],
|
||||||
) -> Dict:
|
optional_params: dict,
|
||||||
for param, value in non_default_params.items():
|
api_key: Optional[str] = None,
|
||||||
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
|
api_base: Optional[str] = None,
|
||||||
if param == "temperature":
|
) -> dict:
|
||||||
if value == 0.0 or value == 0:
|
default_headers = {
|
||||||
# hugging face exception raised when temp==0
|
"content-type": "application/json",
|
||||||
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
|
}
|
||||||
value = 0.01
|
if api_key is not None:
|
||||||
optional_params["temperature"] = value
|
default_headers["Authorization"] = f"Bearer {api_key}"
|
||||||
if param == "top_p":
|
|
||||||
optional_params["top_p"] = value
|
|
||||||
if param == "n":
|
|
||||||
optional_params["best_of"] = value
|
|
||||||
optional_params[
|
|
||||||
"do_sample"
|
|
||||||
] = True # Need to sample if you want best of for hf inference endpoints
|
|
||||||
if param == "stream":
|
|
||||||
optional_params["stream"] = value
|
|
||||||
if param == "stop":
|
|
||||||
optional_params["stop"] = value
|
|
||||||
if param == "max_tokens" or param == "max_completion_tokens":
|
|
||||||
# HF TGI raises the following exception when max_new_tokens==0
|
|
||||||
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
|
|
||||||
if value == 0:
|
|
||||||
value = 1
|
|
||||||
optional_params["max_new_tokens"] = value
|
|
||||||
if param == "echo":
|
|
||||||
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
|
|
||||||
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
|
|
||||||
optional_params["decoder_input_details"] = True
|
|
||||||
|
|
||||||
return optional_params
|
headers = {**headers, **default_headers}
|
||||||
|
|
||||||
def get_hf_api_key(self) -> Optional[str]:
|
return headers
|
||||||
return get_secret_str("HUGGINGFACE_API_KEY")
|
|
||||||
|
|
||||||
def read_tgi_conv_models(self):
|
def get_error_class(
|
||||||
try:
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
global tgi_models_cache, conv_models_cache
|
) -> BaseLLMException:
|
||||||
# Check if the cache is already populated
|
return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
|
||||||
# so we don't keep on reading txt file if there are 1k requests
|
|
||||||
if (tgi_models_cache is not None) and (conv_models_cache is not None):
|
|
||||||
return tgi_models_cache, conv_models_cache
|
|
||||||
# If not, read the file and populate the cache
|
|
||||||
tgi_models = set()
|
|
||||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
script_directory = os.path.dirname(script_directory)
|
|
||||||
# Construct the file path relative to the script's directory
|
|
||||||
file_path = os.path.join(
|
|
||||||
script_directory,
|
|
||||||
"huggingface_llms_metadata",
|
|
||||||
"hf_text_generation_models.txt",
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(file_path, "r") as file:
|
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
|
||||||
for line in file:
|
"""
|
||||||
tgi_models.add(line.strip())
|
Get the API base for the Huggingface API.
|
||||||
|
|
||||||
# Cache the set for future use
|
Do not add the chat/embedding/rerank extension here. Let the handler do this.
|
||||||
tgi_models_cache = tgi_models
|
"""
|
||||||
|
if model.startswith(("http://", "https://")):
|
||||||
|
base_url = model
|
||||||
|
elif base_url is None:
|
||||||
|
base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
|
||||||
|
return base_url
|
||||||
|
|
||||||
# If not, read the file and populate the cache
|
def get_complete_url(
|
||||||
file_path = os.path.join(
|
self,
|
||||||
script_directory,
|
api_base: Optional[str],
|
||||||
"huggingface_llms_metadata",
|
api_key: Optional[str],
|
||||||
"hf_conversational_models.txt",
|
model: str,
|
||||||
)
|
optional_params: dict,
|
||||||
conv_models = set()
|
litellm_params: dict,
|
||||||
with open(file_path, "r") as file:
|
stream: Optional[bool] = None,
|
||||||
for line in file:
|
) -> str:
|
||||||
conv_models.add(line.strip())
|
"""
|
||||||
# Cache the set for future use
|
Get the complete URL for the API call.
|
||||||
conv_models_cache = conv_models
|
For provider-specific routing through huggingface
|
||||||
return tgi_models, conv_models
|
"""
|
||||||
except Exception:
|
# 1. Check if api_base is provided
|
||||||
return set(), set()
|
if api_base is not None:
|
||||||
|
complete_url = api_base
|
||||||
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
|
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
|
||||||
# read text file, cast it to set
|
complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
|
||||||
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
elif model.startswith(("http://", "https://")):
|
||||||
if model.split("/")[0] in hf_task_list:
|
complete_url = model
|
||||||
split_model = model.split("/", 1)
|
# 4. Default construction with provider
|
||||||
return split_model[0], split_model[1] # type: ignore
|
|
||||||
tgi_models, conversational_models = self.read_tgi_conv_models()
|
|
||||||
|
|
||||||
if model in tgi_models:
|
|
||||||
return "text-generation-inference", model
|
|
||||||
elif model in conversational_models:
|
|
||||||
return "conversational", model
|
|
||||||
elif "roneneldan/TinyStories" in model:
|
|
||||||
return "text-generation", model
|
|
||||||
else:
|
else:
|
||||||
return "text-generation-inference", model # default to tgi
|
# Parse provider and model
|
||||||
|
first_part, remaining = model.split("/", 1)
|
||||||
|
if "/" in remaining:
|
||||||
|
provider = first_part
|
||||||
|
else:
|
||||||
|
provider = "hf-inference"
|
||||||
|
|
||||||
|
if provider == "hf-inference":
|
||||||
|
route = f"{provider}/models/{model}/v1/chat/completions"
|
||||||
|
elif provider == "novita":
|
||||||
|
route = f"{provider}/chat/completions"
|
||||||
|
else:
|
||||||
|
route = f"{provider}/v1/chat/completions"
|
||||||
|
complete_url = f"{BASE_URL}/{route}"
|
||||||
|
|
||||||
|
# Ensure URL doesn't end with a slash
|
||||||
|
complete_url = complete_url.rstrip("/")
|
||||||
|
return complete_url
|
||||||
|
|
||||||
def transform_request(
|
def transform_request(
|
||||||
self,
|
self,
|
||||||
|
@ -209,381 +114,28 @@ class HuggingfaceChatConfig(BaseConfig):
|
||||||
litellm_params: dict,
|
litellm_params: dict,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
task = litellm_params.get("task", None)
|
if "max_retries" in optional_params:
|
||||||
## VALIDATE API FORMAT
|
logger.warning("`max_retries` is not supported. It will be ignored.")
|
||||||
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
optional_params.pop("max_retries", None)
|
||||||
raise Exception(
|
first_part, remaining = model.split("/", 1)
|
||||||
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
|
if "/" in remaining:
|
||||||
)
|
provider = first_part
|
||||||
|
model_id = remaining
|
||||||
## Load Config
|
|
||||||
config = litellm.HuggingfaceConfig.get_config()
|
|
||||||
for k, v in config.items():
|
|
||||||
if (
|
|
||||||
k not in optional_params
|
|
||||||
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
|
|
||||||
optional_params[k] = v
|
|
||||||
|
|
||||||
### MAP INPUT PARAMS
|
|
||||||
#### HANDLE SPECIAL PARAMS
|
|
||||||
special_params = self.get_special_options_params()
|
|
||||||
special_params_dict = {}
|
|
||||||
# Create a list of keys to pop after iteration
|
|
||||||
keys_to_pop = []
|
|
||||||
|
|
||||||
for k, v in optional_params.items():
|
|
||||||
if k in special_params:
|
|
||||||
special_params_dict[k] = v
|
|
||||||
keys_to_pop.append(k)
|
|
||||||
|
|
||||||
# Pop the keys from the dictionary after iteration
|
|
||||||
for k in keys_to_pop:
|
|
||||||
optional_params.pop(k)
|
|
||||||
if task == "conversational":
|
|
||||||
inference_params = deepcopy(optional_params)
|
|
||||||
inference_params.pop("details")
|
|
||||||
inference_params.pop("return_full_text")
|
|
||||||
past_user_inputs = []
|
|
||||||
generated_responses = []
|
|
||||||
text = ""
|
|
||||||
for message in messages:
|
|
||||||
if message["role"] == "user":
|
|
||||||
if text != "":
|
|
||||||
past_user_inputs.append(text)
|
|
||||||
text = convert_content_list_to_str(message)
|
|
||||||
elif message["role"] == "assistant" or message["role"] == "system":
|
|
||||||
generated_responses.append(convert_content_list_to_str(message))
|
|
||||||
data = {
|
|
||||||
"inputs": {
|
|
||||||
"text": text,
|
|
||||||
"past_user_inputs": past_user_inputs,
|
|
||||||
"generated_responses": generated_responses,
|
|
||||||
},
|
|
||||||
"parameters": inference_params,
|
|
||||||
}
|
|
||||||
|
|
||||||
elif task == "text-generation-inference":
|
|
||||||
# always send "details" and "return_full_text" as params
|
|
||||||
if model in litellm.custom_prompt_dict:
|
|
||||||
# check if the model has a registered custom prompt
|
|
||||||
model_prompt_details = litellm.custom_prompt_dict[model]
|
|
||||||
prompt = custom_prompt(
|
|
||||||
role_dict=model_prompt_details.get("roles", None),
|
|
||||||
initial_prompt_value=model_prompt_details.get(
|
|
||||||
"initial_prompt_value", ""
|
|
||||||
),
|
|
||||||
final_prompt_value=model_prompt_details.get(
|
|
||||||
"final_prompt_value", ""
|
|
||||||
),
|
|
||||||
messages=messages,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = prompt_factory(model=model, messages=messages)
|
|
||||||
data = {
|
|
||||||
"inputs": prompt, # type: ignore
|
|
||||||
"parameters": optional_params,
|
|
||||||
"stream": ( # type: ignore
|
|
||||||
True
|
|
||||||
if "stream" in optional_params
|
|
||||||
and isinstance(optional_params["stream"], bool)
|
|
||||||
and optional_params["stream"] is True # type: ignore
|
|
||||||
else False
|
|
||||||
),
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
# Non TGI and Conversational llms
|
provider = "hf-inference"
|
||||||
# We need this branch, it removes 'details' and 'return_full_text' from params
|
model_id = model
|
||||||
if model in litellm.custom_prompt_dict:
|
provider_mapping = _fetch_inference_provider_mapping(model_id)
|
||||||
# check if the model has a registered custom prompt
|
if provider not in provider_mapping:
|
||||||
model_prompt_details = litellm.custom_prompt_dict[model]
|
raise HuggingFaceError(
|
||||||
prompt = custom_prompt(
|
message=f"Model {model_id} is not supported for provider {provider}",
|
||||||
role_dict=model_prompt_details.get("roles", {}),
|
status_code=404,
|
||||||
initial_prompt_value=model_prompt_details.get(
|
headers={},
|
||||||
"initial_prompt_value", ""
|
|
||||||
),
|
|
||||||
final_prompt_value=model_prompt_details.get(
|
|
||||||
"final_prompt_value", ""
|
|
||||||
),
|
|
||||||
bos_token=model_prompt_details.get("bos_token", ""),
|
|
||||||
eos_token=model_prompt_details.get("eos_token", ""),
|
|
||||||
messages=messages,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = prompt_factory(model=model, messages=messages)
|
|
||||||
inference_params = deepcopy(optional_params)
|
|
||||||
inference_params.pop("details")
|
|
||||||
inference_params.pop("return_full_text")
|
|
||||||
data = {
|
|
||||||
"inputs": prompt, # type: ignore
|
|
||||||
}
|
|
||||||
if task == "text-generation-inference":
|
|
||||||
data["parameters"] = inference_params
|
|
||||||
data["stream"] = ( # type: ignore
|
|
||||||
True # type: ignore
|
|
||||||
if "stream" in optional_params and optional_params["stream"] is True
|
|
||||||
else False
|
|
||||||
)
|
|
||||||
|
|
||||||
### RE-ADD SPECIAL PARAMS
|
|
||||||
if len(special_params_dict.keys()) > 0:
|
|
||||||
data.update({"options": special_params_dict})
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def get_api_base(self, api_base: Optional[str], model: str) -> str:
|
|
||||||
"""
|
|
||||||
Get the API base for the Huggingface API.
|
|
||||||
|
|
||||||
Do not add the chat/embedding/rerank extension here. Let the handler do this.
|
|
||||||
"""
|
|
||||||
if "https" in model:
|
|
||||||
completion_url = model
|
|
||||||
elif api_base is not None:
|
|
||||||
completion_url = api_base
|
|
||||||
elif "HF_API_BASE" in os.environ:
|
|
||||||
completion_url = os.getenv("HF_API_BASE", "")
|
|
||||||
elif "HUGGINGFACE_API_BASE" in os.environ:
|
|
||||||
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
|
||||||
else:
|
|
||||||
completion_url = f"https://api-inference.huggingface.co/models/{model}"
|
|
||||||
|
|
||||||
return completion_url
|
|
||||||
|
|
||||||
def validate_environment(
|
|
||||||
self,
|
|
||||||
headers: Dict,
|
|
||||||
model: str,
|
|
||||||
messages: List[AllMessageValues],
|
|
||||||
optional_params: Dict,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
api_base: Optional[str] = None,
|
|
||||||
) -> Dict:
|
|
||||||
default_headers = {
|
|
||||||
"content-type": "application/json",
|
|
||||||
}
|
|
||||||
if api_key is not None:
|
|
||||||
default_headers[
|
|
||||||
"Authorization"
|
|
||||||
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
|
|
||||||
|
|
||||||
headers = {**headers, **default_headers}
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def get_error_class(
|
|
||||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
|
||||||
) -> BaseLLMException:
|
|
||||||
return HuggingfaceError(
|
|
||||||
status_code=status_code, message=error_message, headers=headers
|
|
||||||
)
|
|
||||||
|
|
||||||
def _convert_streamed_response_to_complete_response(
|
|
||||||
self,
|
|
||||||
response: httpx.Response,
|
|
||||||
logging_obj: LoggingClass,
|
|
||||||
model: str,
|
|
||||||
data: dict,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
streamed_response = CustomStreamWrapper(
|
|
||||||
completion_stream=response.iter_lines(),
|
|
||||||
model=model,
|
|
||||||
custom_llm_provider="huggingface",
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
)
|
|
||||||
content = ""
|
|
||||||
for chunk in streamed_response:
|
|
||||||
content += chunk["choices"][0]["delta"]["content"]
|
|
||||||
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=data,
|
|
||||||
api_key=api_key,
|
|
||||||
original_response=completion_response,
|
|
||||||
additional_args={"complete_input_dict": data},
|
|
||||||
)
|
|
||||||
return completion_response
|
|
||||||
|
|
||||||
def convert_to_model_response_object( # noqa: PLR0915
|
|
||||||
self,
|
|
||||||
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
|
|
||||||
model_response: ModelResponse,
|
|
||||||
task: Optional[hf_tasks],
|
|
||||||
optional_params: dict,
|
|
||||||
encoding: Any,
|
|
||||||
messages: List[AllMessageValues],
|
|
||||||
model: str,
|
|
||||||
):
|
|
||||||
if task is None:
|
|
||||||
task = "text-generation-inference" # default to tgi
|
|
||||||
|
|
||||||
if task == "conversational":
|
|
||||||
if len(completion_response["generated_text"]) > 0: # type: ignore
|
|
||||||
model_response.choices[0].message.content = completion_response[ # type: ignore
|
|
||||||
"generated_text"
|
|
||||||
]
|
|
||||||
elif task == "text-generation-inference":
|
|
||||||
if (
|
|
||||||
not isinstance(completion_response, list)
|
|
||||||
or not isinstance(completion_response[0], dict)
|
|
||||||
or "generated_text" not in completion_response[0]
|
|
||||||
):
|
|
||||||
raise HuggingfaceError(
|
|
||||||
status_code=422,
|
|
||||||
message=f"response is not in expected format - {completion_response}",
|
|
||||||
headers=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(completion_response[0]["generated_text"]) > 0:
|
|
||||||
model_response.choices[0].message.content = output_parser( # type: ignore
|
|
||||||
completion_response[0]["generated_text"]
|
|
||||||
)
|
|
||||||
## GETTING LOGPROBS + FINISH REASON
|
|
||||||
if (
|
|
||||||
"details" in completion_response[0]
|
|
||||||
and "tokens" in completion_response[0]["details"]
|
|
||||||
):
|
|
||||||
model_response.choices[0].finish_reason = completion_response[0][
|
|
||||||
"details"
|
|
||||||
]["finish_reason"]
|
|
||||||
sum_logprob = 0
|
|
||||||
for token in completion_response[0]["details"]["tokens"]:
|
|
||||||
if token["logprob"] is not None:
|
|
||||||
sum_logprob += token["logprob"]
|
|
||||||
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
|
|
||||||
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
|
||||||
if (
|
|
||||||
"details" in completion_response[0]
|
|
||||||
and "best_of_sequences" in completion_response[0]["details"]
|
|
||||||
):
|
|
||||||
choices_list = []
|
|
||||||
for idx, item in enumerate(
|
|
||||||
completion_response[0]["details"]["best_of_sequences"]
|
|
||||||
):
|
|
||||||
sum_logprob = 0
|
|
||||||
for token in item["tokens"]:
|
|
||||||
if token["logprob"] is not None:
|
|
||||||
sum_logprob += token["logprob"]
|
|
||||||
if len(item["generated_text"]) > 0:
|
|
||||||
message_obj = Message(
|
|
||||||
content=output_parser(item["generated_text"]),
|
|
||||||
logprobs=sum_logprob,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message_obj = Message(content=None)
|
|
||||||
choice_obj = Choices(
|
|
||||||
finish_reason=item["finish_reason"],
|
|
||||||
index=idx + 1,
|
|
||||||
message=message_obj,
|
|
||||||
)
|
|
||||||
choices_list.append(choice_obj)
|
|
||||||
model_response.choices.extend(choices_list)
|
|
||||||
elif task == "text-classification":
|
|
||||||
model_response.choices[0].message.content = json.dumps( # type: ignore
|
|
||||||
completion_response
|
|
||||||
)
|
)
|
||||||
else:
|
provider_mapping = provider_mapping[provider]
|
||||||
if (
|
if provider_mapping["status"] == "staging":
|
||||||
isinstance(completion_response, list)
|
logger.warning(
|
||||||
and len(completion_response[0]["generated_text"]) > 0
|
f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
|
||||||
):
|
|
||||||
model_response.choices[0].message.content = output_parser( # type: ignore
|
|
||||||
completion_response[0]["generated_text"]
|
|
||||||
)
|
|
||||||
## CALCULATING USAGE
|
|
||||||
prompt_tokens = 0
|
|
||||||
try:
|
|
||||||
prompt_tokens = token_counter(model=model, messages=messages)
|
|
||||||
except Exception:
|
|
||||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
|
||||||
pass
|
|
||||||
output_text = model_response["choices"][0]["message"].get("content", "")
|
|
||||||
if output_text is not None and len(output_text) > 0:
|
|
||||||
completion_tokens = 0
|
|
||||||
try:
|
|
||||||
completion_tokens = len(
|
|
||||||
encoding.encode(
|
|
||||||
model_response["choices"][0]["message"].get("content", "")
|
|
||||||
)
|
|
||||||
) ##[TODO] use the llama2 tokenizer here
|
|
||||||
except Exception:
|
|
||||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
completion_tokens = 0
|
|
||||||
|
|
||||||
model_response.created = int(time.time())
|
|
||||||
model_response.model = model
|
|
||||||
usage = Usage(
|
|
||||||
prompt_tokens=prompt_tokens,
|
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
|
||||||
)
|
|
||||||
setattr(model_response, "usage", usage)
|
|
||||||
model_response._hidden_params["original_response"] = completion_response
|
|
||||||
return model_response
|
|
||||||
|
|
||||||
def transform_response(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
raw_response: httpx.Response,
|
|
||||||
model_response: ModelResponse,
|
|
||||||
logging_obj: LoggingClass,
|
|
||||||
request_data: Dict,
|
|
||||||
messages: List[AllMessageValues],
|
|
||||||
optional_params: Dict,
|
|
||||||
litellm_params: Dict,
|
|
||||||
encoding: Any,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
json_mode: Optional[bool] = None,
|
|
||||||
) -> ModelResponse:
|
|
||||||
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
|
||||||
task = litellm_params.get("task", None)
|
|
||||||
is_streamed = False
|
|
||||||
if (
|
|
||||||
raw_response.__dict__["headers"].get("Content-Type", "")
|
|
||||||
== "text/event-stream"
|
|
||||||
):
|
|
||||||
is_streamed = True
|
|
||||||
|
|
||||||
# iterate over the complete streamed response, and return the final answer
|
|
||||||
if is_streamed:
|
|
||||||
completion_response = self._convert_streamed_response_to_complete_response(
|
|
||||||
response=raw_response,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
model=model,
|
|
||||||
data=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
)
|
)
|
||||||
else:
|
mapped_model = provider_mapping["providerId"]
|
||||||
## LOGGING
|
messages = self._transform_messages(messages=messages, model=mapped_model)
|
||||||
logging_obj.post_call(
|
return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
|
||||||
input=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
original_response=raw_response.text,
|
|
||||||
additional_args={"complete_input_dict": request_data},
|
|
||||||
)
|
|
||||||
## RESPONSE OBJECT
|
|
||||||
try:
|
|
||||||
completion_response = raw_response.json()
|
|
||||||
if isinstance(completion_response, dict):
|
|
||||||
completion_response = [completion_response]
|
|
||||||
except Exception:
|
|
||||||
raise HuggingfaceError(
|
|
||||||
message=f"Original Response received: {raw_response.text}",
|
|
||||||
status_code=raw_response.status_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(completion_response, dict) and "error" in completion_response:
|
|
||||||
raise HuggingfaceError(
|
|
||||||
message=completion_response["error"], # type: ignore
|
|
||||||
status_code=raw_response.status_code,
|
|
||||||
)
|
|
||||||
return self.convert_to_model_response_object(
|
|
||||||
completion_response=completion_response,
|
|
||||||
model_response=model_response,
|
|
||||||
task=task if task is not None and task in hf_task_list else None,
|
|
||||||
optional_params=optional_params,
|
|
||||||
encoding=encoding,
|
|
||||||
messages=messages,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,18 +1,30 @@
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
|
||||||
|
HF_HUB_URL = "https://huggingface.co"
|
||||||
|
|
||||||
class HuggingfaceError(BaseLLMException):
|
|
||||||
|
class HuggingFaceError(BaseLLMException):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
status_code: int,
|
status_code,
|
||||||
message: str,
|
message,
|
||||||
headers: Optional[Union[dict, httpx.Headers]] = None,
|
request: Optional[httpx.Request] = None,
|
||||||
|
response: Optional[httpx.Response] = None,
|
||||||
|
headers: Optional[Union[httpx.Headers, dict]] = None,
|
||||||
):
|
):
|
||||||
super().__init__(status_code=status_code, message=message, headers=headers)
|
super().__init__(
|
||||||
|
status_code=status_code,
|
||||||
|
message=message,
|
||||||
|
request=request,
|
||||||
|
response=response,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
hf_tasks = Literal[
|
hf_tasks = Literal[
|
||||||
|
@ -43,3 +55,48 @@ def output_parser(generated_text: str):
|
||||||
if generated_text.endswith(token):
|
if generated_text.endswith(token):
|
||||||
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
|
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
|
||||||
return generated_text
|
return generated_text
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=128)
|
||||||
|
def _fetch_inference_provider_mapping(model: str) -> dict:
|
||||||
|
"""
|
||||||
|
Fetch provider mappings for a model from the Hugging Face Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model identifier (e.g., 'meta-llama/Llama-2-7b')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The inference provider mapping for the model
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no provider mapping is found
|
||||||
|
HuggingFaceError: If the API request fails
|
||||||
|
"""
|
||||||
|
headers = {"Accept": "application/json"}
|
||||||
|
if os.getenv("HUGGINGFACE_API_KEY"):
|
||||||
|
headers["Authorization"] = f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"
|
||||||
|
|
||||||
|
path = f"{HF_HUB_URL}/api/models/{model}"
|
||||||
|
params = {"expand": ["inferenceProviderMapping"]}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.get(path, headers=headers, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
provider_mapping = response.json().get("inferenceProviderMapping")
|
||||||
|
|
||||||
|
if provider_mapping is None:
|
||||||
|
raise ValueError(f"No provider mapping found for model {model}")
|
||||||
|
|
||||||
|
return provider_mapping
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
if hasattr(e, "response"):
|
||||||
|
status_code = getattr(e.response, "status_code", 500)
|
||||||
|
headers = getattr(e.response, "headers", {})
|
||||||
|
else:
|
||||||
|
status_code = 500
|
||||||
|
headers = {}
|
||||||
|
raise HuggingFaceError(
|
||||||
|
message=f"Failed to fetch provider mapping: {str(e)}",
|
||||||
|
status_code=status_code,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
421
litellm/llms/huggingface/embedding/handler.py
Normal file
421
litellm/llms/huggingface/embedding/handler.py
Normal file
|
@ -0,0 +1,421 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
HTTPHandler,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.types.utils import EmbeddingResponse
|
||||||
|
|
||||||
|
from ...base import BaseLLM
|
||||||
|
from ..common_utils import HuggingFaceError
|
||||||
|
from .transformation import HuggingFaceEmbeddingConfig
|
||||||
|
|
||||||
|
config = HuggingFaceEmbeddingConfig()
|
||||||
|
|
||||||
|
HF_HUB_URL = "https://huggingface.co"
|
||||||
|
|
||||||
|
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
|
||||||
|
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||||
|
if task_type is not None:
|
||||||
|
if task_type in get_args(hf_tasks_embeddings):
|
||||||
|
return task_type
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid task_type={}. Expected one of={}".format(
|
||||||
|
task_type, hf_tasks_embeddings
|
||||||
|
)
|
||||||
|
)
|
||||||
|
http_client = HTTPHandler(concurrent_limit=1)
|
||||||
|
|
||||||
|
model_info = http_client.get(url=f"{api_base}/api/models/{model}")
|
||||||
|
|
||||||
|
model_info_dict = model_info.json()
|
||||||
|
|
||||||
|
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
|
||||||
|
|
||||||
|
return pipeline_tag
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||||
|
if task_type is not None:
|
||||||
|
if task_type in get_args(hf_tasks_embeddings):
|
||||||
|
return task_type
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid task_type={}. Expected one of={}".format(
|
||||||
|
task_type, hf_tasks_embeddings
|
||||||
|
)
|
||||||
|
)
|
||||||
|
http_client = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.HUGGINGFACE,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_info = await http_client.get(url=f"{api_base}/api/models/{model}")
|
||||||
|
|
||||||
|
model_info_dict = model_info.json()
|
||||||
|
|
||||||
|
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
|
||||||
|
|
||||||
|
return pipeline_tag
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceEmbedding(BaseLLM):
|
||||||
|
_client_session: Optional[httpx.Client] = None
|
||||||
|
_aclient_session: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _transform_input_on_pipeline_tag(
|
||||||
|
self, input: List, pipeline_tag: Optional[str]
|
||||||
|
) -> dict:
|
||||||
|
if pipeline_tag is None:
|
||||||
|
return {"inputs": input}
|
||||||
|
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
|
||||||
|
if len(input) < 2:
|
||||||
|
raise HuggingFaceError(
|
||||||
|
status_code=400,
|
||||||
|
message="sentence-similarity requires 2+ sentences",
|
||||||
|
)
|
||||||
|
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
||||||
|
elif pipeline_tag == "rerank":
|
||||||
|
if len(input) < 2:
|
||||||
|
raise HuggingFaceError(
|
||||||
|
status_code=400,
|
||||||
|
message="reranker requires 2+ sentences",
|
||||||
|
)
|
||||||
|
return {"inputs": {"query": input[0], "texts": input[1:]}}
|
||||||
|
return {"inputs": input} # default to feature-extraction pipeline tag
|
||||||
|
|
||||||
|
async def _async_transform_input(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
task_type: Optional[str],
|
||||||
|
embed_url: str,
|
||||||
|
input: List,
|
||||||
|
optional_params: dict,
|
||||||
|
) -> dict:
|
||||||
|
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||||
|
|
||||||
|
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
||||||
|
|
||||||
|
if len(optional_params.keys()) > 0:
|
||||||
|
data["options"] = optional_params
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
|
||||||
|
special_options_keys = config.get_special_options_params()
|
||||||
|
special_parameters_keys = [
|
||||||
|
"min_length",
|
||||||
|
"max_length",
|
||||||
|
"top_k",
|
||||||
|
"top_p",
|
||||||
|
"temperature",
|
||||||
|
"repetition_penalty",
|
||||||
|
"max_time",
|
||||||
|
]
|
||||||
|
|
||||||
|
for k, v in optional_params.items():
|
||||||
|
if k in special_options_keys:
|
||||||
|
data.setdefault("options", {})
|
||||||
|
data["options"][k] = v
|
||||||
|
elif k in special_parameters_keys:
|
||||||
|
data.setdefault("parameters", {})
|
||||||
|
data["parameters"][k] = v
|
||||||
|
else:
|
||||||
|
data[k] = v
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _transform_input(
|
||||||
|
self,
|
||||||
|
input: List,
|
||||||
|
model: str,
|
||||||
|
call_type: Literal["sync", "async"],
|
||||||
|
optional_params: dict,
|
||||||
|
embed_url: str,
|
||||||
|
) -> dict:
|
||||||
|
data: Dict = {}
|
||||||
|
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
if "sentence-transformers" in model:
|
||||||
|
if len(input) == 0:
|
||||||
|
raise HuggingFaceError(
|
||||||
|
status_code=400,
|
||||||
|
message="sentence transformers requires 2+ sentences",
|
||||||
|
)
|
||||||
|
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
||||||
|
else:
|
||||||
|
data = {"inputs": input}
|
||||||
|
|
||||||
|
task_type = optional_params.pop("input_type", None)
|
||||||
|
|
||||||
|
if call_type == "sync":
|
||||||
|
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||||
|
elif call_type == "async":
|
||||||
|
return self._async_transform_input(
|
||||||
|
model=model, task_type=task_type, embed_url=embed_url, input=input
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
data = self._transform_input_on_pipeline_tag(
|
||||||
|
input=input, pipeline_tag=hf_task
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(optional_params.keys()) > 0:
|
||||||
|
data = self._process_optional_params(
|
||||||
|
data=data, optional_params=optional_params
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _process_embedding_response(
|
||||||
|
self,
|
||||||
|
embeddings: dict,
|
||||||
|
model_response: EmbeddingResponse,
|
||||||
|
model: str,
|
||||||
|
input: List,
|
||||||
|
encoding: Any,
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
output_data = []
|
||||||
|
if "similarities" in embeddings:
|
||||||
|
for idx, embedding in embeddings["similarities"]:
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding, # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for idx, embedding in enumerate(embeddings):
|
||||||
|
if isinstance(embedding, float):
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding, # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(embedding, list) and isinstance(embedding[0], float):
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding, # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding[0][
|
||||||
|
0
|
||||||
|
], # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
model_response.object = "list"
|
||||||
|
model_response.data = output_data
|
||||||
|
model_response.model = model
|
||||||
|
input_tokens = 0
|
||||||
|
for text in input:
|
||||||
|
input_tokens += len(encoding.encode(text))
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
model_response,
|
||||||
|
"usage",
|
||||||
|
litellm.Usage(
|
||||||
|
prompt_tokens=input_tokens,
|
||||||
|
completion_tokens=input_tokens,
|
||||||
|
total_tokens=input_tokens,
|
||||||
|
prompt_tokens_details=None,
|
||||||
|
completion_tokens_details=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
async def aembedding(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
input: list,
|
||||||
|
model_response: litellm.utils.EmbeddingResponse,
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
optional_params: dict,
|
||||||
|
api_base: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
headers: dict,
|
||||||
|
encoding: Callable,
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
):
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
data = self._transform_input(
|
||||||
|
input=input,
|
||||||
|
model=model,
|
||||||
|
call_type="sync",
|
||||||
|
optional_params=optional_params,
|
||||||
|
embed_url=api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"headers": headers,
|
||||||
|
"api_base": api_base,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
## COMPLETION CALL
|
||||||
|
if client is None:
|
||||||
|
client = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.HUGGINGFACE,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = response.json()
|
||||||
|
|
||||||
|
if "error" in embeddings:
|
||||||
|
raise HuggingFaceError(status_code=500, message=embeddings["error"])
|
||||||
|
|
||||||
|
## PROCESS RESPONSE ##
|
||||||
|
return self._process_embedding_response(
|
||||||
|
embeddings=embeddings,
|
||||||
|
model_response=model_response,
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
def embedding(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
input: list,
|
||||||
|
model_response: EmbeddingResponse,
|
||||||
|
optional_params: dict,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
encoding: Callable,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||||
|
aembedding: Optional[bool] = None,
|
||||||
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||||
|
headers={},
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
super().embedding()
|
||||||
|
headers = config.validate_environment(
|
||||||
|
api_key=api_key,
|
||||||
|
headers=headers,
|
||||||
|
model=model,
|
||||||
|
optional_params=optional_params,
|
||||||
|
messages=[],
|
||||||
|
)
|
||||||
|
task_type = optional_params.pop("input_type", None)
|
||||||
|
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||||
|
# print_verbose(f"{model}, {task}")
|
||||||
|
embed_url = ""
|
||||||
|
if "https" in model:
|
||||||
|
embed_url = model
|
||||||
|
elif api_base:
|
||||||
|
embed_url = api_base
|
||||||
|
elif "HF_API_BASE" in os.environ:
|
||||||
|
embed_url = os.getenv("HF_API_BASE", "")
|
||||||
|
elif "HUGGINGFACE_API_BASE" in os.environ:
|
||||||
|
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
||||||
|
else:
|
||||||
|
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
|
||||||
|
|
||||||
|
## ROUTING ##
|
||||||
|
if aembedding is True:
|
||||||
|
return self.aembedding(
|
||||||
|
input=input,
|
||||||
|
model_response=model_response,
|
||||||
|
timeout=timeout,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
headers=headers,
|
||||||
|
api_base=embed_url, # type: ignore
|
||||||
|
api_key=api_key,
|
||||||
|
client=client if isinstance(client, AsyncHTTPHandler) else None,
|
||||||
|
model=model,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
|
||||||
|
data = self._transform_input(
|
||||||
|
input=input,
|
||||||
|
model=model,
|
||||||
|
call_type="sync",
|
||||||
|
optional_params=optional_params,
|
||||||
|
embed_url=embed_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"headers": headers,
|
||||||
|
"api_base": embed_url,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
## COMPLETION CALL
|
||||||
|
if client is None or not isinstance(client, HTTPHandler):
|
||||||
|
client = HTTPHandler(concurrent_limit=1)
|
||||||
|
response = client.post(embed_url, headers=headers, data=json.dumps(data))
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = response.json()
|
||||||
|
|
||||||
|
if "error" in embeddings:
|
||||||
|
raise HuggingFaceError(status_code=500, message=embeddings["error"])
|
||||||
|
|
||||||
|
## PROCESS RESPONSE ##
|
||||||
|
return self._process_embedding_response(
|
||||||
|
embeddings=embeddings,
|
||||||
|
model_response=model_response,
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
589
litellm/llms/huggingface/embedding/transformation.py
Normal file
589
litellm/llms/huggingface/embedding/transformation.py
Normal file
|
@ -0,0 +1,589 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||||
|
convert_content_list_to_str,
|
||||||
|
)
|
||||||
|
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||||
|
custom_prompt,
|
||||||
|
prompt_factory,
|
||||||
|
)
|
||||||
|
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||||
|
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||||
|
from litellm.secret_managers.main import get_secret_str
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||||
|
from litellm.utils import token_counter
|
||||||
|
|
||||||
|
from ..common_utils import HuggingFaceError, hf_task_list, hf_tasks, output_parser
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
|
||||||
|
LoggingClass = LiteLLMLoggingObj
|
||||||
|
else:
|
||||||
|
LoggingClass = Any
|
||||||
|
|
||||||
|
|
||||||
|
tgi_models_cache = None
|
||||||
|
conv_models_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceEmbeddingConfig(BaseConfig):
|
||||||
|
"""
|
||||||
|
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||||
|
"""
|
||||||
|
|
||||||
|
hf_task: Optional[
|
||||||
|
hf_tasks
|
||||||
|
] = None # litellm-specific param, used to know the api spec to use when calling huggingface api
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
decoder_input_details: Optional[bool] = None
|
||||||
|
details: Optional[bool] = True # enables returning logprobs + best of
|
||||||
|
max_new_tokens: Optional[int] = None
|
||||||
|
repetition_penalty: Optional[float] = None
|
||||||
|
return_full_text: Optional[
|
||||||
|
bool
|
||||||
|
] = False # by default don't return the input as part of the output
|
||||||
|
seed: Optional[int] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
top_n_tokens: Optional[int] = None
|
||||||
|
top_p: Optional[int] = None
|
||||||
|
truncate: Optional[int] = None
|
||||||
|
typical_p: Optional[float] = None
|
||||||
|
watermark: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
decoder_input_details: Optional[bool] = None,
|
||||||
|
details: Optional[bool] = None,
|
||||||
|
max_new_tokens: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: Optional[bool] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_n_tokens: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals().copy()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return super().get_config()
|
||||||
|
|
||||||
|
def get_special_options_params(self):
|
||||||
|
return ["use_cache", "wait_for_model"]
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str):
|
||||||
|
return [
|
||||||
|
"stream",
|
||||||
|
"temperature",
|
||||||
|
"max_tokens",
|
||||||
|
"max_completion_tokens",
|
||||||
|
"top_p",
|
||||||
|
"stop",
|
||||||
|
"n",
|
||||||
|
"echo",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: Dict,
|
||||||
|
optional_params: Dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> Dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
|
||||||
|
if param == "temperature":
|
||||||
|
if value == 0.0 or value == 0:
|
||||||
|
# hugging face exception raised when temp==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
|
||||||
|
value = 0.01
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if param == "n":
|
||||||
|
optional_params["best_of"] = value
|
||||||
|
optional_params[
|
||||||
|
"do_sample"
|
||||||
|
] = True # Need to sample if you want best of for hf inference endpoints
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
optional_params["stop"] = value
|
||||||
|
if param == "max_tokens" or param == "max_completion_tokens":
|
||||||
|
# HF TGI raises the following exception when max_new_tokens==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
|
||||||
|
if value == 0:
|
||||||
|
value = 1
|
||||||
|
optional_params["max_new_tokens"] = value
|
||||||
|
if param == "echo":
|
||||||
|
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
|
||||||
|
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
|
||||||
|
optional_params["decoder_input_details"] = True
|
||||||
|
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def get_hf_api_key(self) -> Optional[str]:
|
||||||
|
return get_secret_str("HUGGINGFACE_API_KEY")
|
||||||
|
|
||||||
|
def read_tgi_conv_models(self):
|
||||||
|
try:
|
||||||
|
global tgi_models_cache, conv_models_cache
|
||||||
|
# Check if the cache is already populated
|
||||||
|
# so we don't keep on reading txt file if there are 1k requests
|
||||||
|
if (tgi_models_cache is not None) and (conv_models_cache is not None):
|
||||||
|
return tgi_models_cache, conv_models_cache
|
||||||
|
# If not, read the file and populate the cache
|
||||||
|
tgi_models = set()
|
||||||
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
script_directory = os.path.dirname(script_directory)
|
||||||
|
# Construct the file path relative to the script's directory
|
||||||
|
file_path = os.path.join(
|
||||||
|
script_directory,
|
||||||
|
"huggingface_llms_metadata",
|
||||||
|
"hf_text_generation_models.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
for line in file:
|
||||||
|
tgi_models.add(line.strip())
|
||||||
|
|
||||||
|
# Cache the set for future use
|
||||||
|
tgi_models_cache = tgi_models
|
||||||
|
|
||||||
|
# If not, read the file and populate the cache
|
||||||
|
file_path = os.path.join(
|
||||||
|
script_directory,
|
||||||
|
"huggingface_llms_metadata",
|
||||||
|
"hf_conversational_models.txt",
|
||||||
|
)
|
||||||
|
conv_models = set()
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
for line in file:
|
||||||
|
conv_models.add(line.strip())
|
||||||
|
# Cache the set for future use
|
||||||
|
conv_models_cache = conv_models
|
||||||
|
return tgi_models, conv_models
|
||||||
|
except Exception:
|
||||||
|
return set(), set()
|
||||||
|
|
||||||
|
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
|
||||||
|
# read text file, cast it to set
|
||||||
|
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||||
|
if model.split("/")[0] in hf_task_list:
|
||||||
|
split_model = model.split("/", 1)
|
||||||
|
return split_model[0], split_model[1] # type: ignore
|
||||||
|
tgi_models, conversational_models = self.read_tgi_conv_models()
|
||||||
|
|
||||||
|
if model in tgi_models:
|
||||||
|
return "text-generation-inference", model
|
||||||
|
elif model in conversational_models:
|
||||||
|
return "conversational", model
|
||||||
|
elif "roneneldan/TinyStories" in model:
|
||||||
|
return "text-generation", model
|
||||||
|
else:
|
||||||
|
return "text-generation-inference", model # default to tgi
|
||||||
|
|
||||||
|
def transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
task = litellm_params.get("task", None)
|
||||||
|
## VALIDATE API FORMAT
|
||||||
|
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
|
||||||
|
)
|
||||||
|
|
||||||
|
## Load Config
|
||||||
|
config = litellm.HuggingFaceEmbeddingConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in optional_params
|
||||||
|
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
optional_params[k] = v
|
||||||
|
|
||||||
|
### MAP INPUT PARAMS
|
||||||
|
#### HANDLE SPECIAL PARAMS
|
||||||
|
special_params = self.get_special_options_params()
|
||||||
|
special_params_dict = {}
|
||||||
|
# Create a list of keys to pop after iteration
|
||||||
|
keys_to_pop = []
|
||||||
|
|
||||||
|
for k, v in optional_params.items():
|
||||||
|
if k in special_params:
|
||||||
|
special_params_dict[k] = v
|
||||||
|
keys_to_pop.append(k)
|
||||||
|
|
||||||
|
# Pop the keys from the dictionary after iteration
|
||||||
|
for k in keys_to_pop:
|
||||||
|
optional_params.pop(k)
|
||||||
|
if task == "conversational":
|
||||||
|
inference_params = deepcopy(optional_params)
|
||||||
|
inference_params.pop("details")
|
||||||
|
inference_params.pop("return_full_text")
|
||||||
|
past_user_inputs = []
|
||||||
|
generated_responses = []
|
||||||
|
text = ""
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "user":
|
||||||
|
if text != "":
|
||||||
|
past_user_inputs.append(text)
|
||||||
|
text = convert_content_list_to_str(message)
|
||||||
|
elif message["role"] == "assistant" or message["role"] == "system":
|
||||||
|
generated_responses.append(convert_content_list_to_str(message))
|
||||||
|
data = {
|
||||||
|
"inputs": {
|
||||||
|
"text": text,
|
||||||
|
"past_user_inputs": past_user_inputs,
|
||||||
|
"generated_responses": generated_responses,
|
||||||
|
},
|
||||||
|
"parameters": inference_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif task == "text-generation-inference":
|
||||||
|
# always send "details" and "return_full_text" as params
|
||||||
|
if model in litellm.custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = litellm.custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details.get("roles", None),
|
||||||
|
initial_prompt_value=model_prompt_details.get(
|
||||||
|
"initial_prompt_value", ""
|
||||||
|
),
|
||||||
|
final_prompt_value=model_prompt_details.get(
|
||||||
|
"final_prompt_value", ""
|
||||||
|
),
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(model=model, messages=messages)
|
||||||
|
data = {
|
||||||
|
"inputs": prompt, # type: ignore
|
||||||
|
"parameters": optional_params,
|
||||||
|
"stream": ( # type: ignore
|
||||||
|
True
|
||||||
|
if "stream" in optional_params
|
||||||
|
and isinstance(optional_params["stream"], bool)
|
||||||
|
and optional_params["stream"] is True # type: ignore
|
||||||
|
else False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Non TGI and Conversational llms
|
||||||
|
# We need this branch, it removes 'details' and 'return_full_text' from params
|
||||||
|
if model in litellm.custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = litellm.custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details.get("roles", {}),
|
||||||
|
initial_prompt_value=model_prompt_details.get(
|
||||||
|
"initial_prompt_value", ""
|
||||||
|
),
|
||||||
|
final_prompt_value=model_prompt_details.get(
|
||||||
|
"final_prompt_value", ""
|
||||||
|
),
|
||||||
|
bos_token=model_prompt_details.get("bos_token", ""),
|
||||||
|
eos_token=model_prompt_details.get("eos_token", ""),
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(model=model, messages=messages)
|
||||||
|
inference_params = deepcopy(optional_params)
|
||||||
|
inference_params.pop("details")
|
||||||
|
inference_params.pop("return_full_text")
|
||||||
|
data = {
|
||||||
|
"inputs": prompt, # type: ignore
|
||||||
|
}
|
||||||
|
if task == "text-generation-inference":
|
||||||
|
data["parameters"] = inference_params
|
||||||
|
data["stream"] = ( # type: ignore
|
||||||
|
True # type: ignore
|
||||||
|
if "stream" in optional_params and optional_params["stream"] is True
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
|
### RE-ADD SPECIAL PARAMS
|
||||||
|
if len(special_params_dict.keys()) > 0:
|
||||||
|
data.update({"options": special_params_dict})
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_api_base(self, api_base: Optional[str], model: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the API base for the Huggingface API.
|
||||||
|
|
||||||
|
Do not add the chat/embedding/rerank extension here. Let the handler do this.
|
||||||
|
"""
|
||||||
|
if "https" in model:
|
||||||
|
completion_url = model
|
||||||
|
elif api_base is not None:
|
||||||
|
completion_url = api_base
|
||||||
|
elif "HF_API_BASE" in os.environ:
|
||||||
|
completion_url = os.getenv("HF_API_BASE", "")
|
||||||
|
elif "HUGGINGFACE_API_BASE" in os.environ:
|
||||||
|
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
||||||
|
else:
|
||||||
|
completion_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||||
|
|
||||||
|
return completion_url
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: Dict,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: Dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
) -> Dict:
|
||||||
|
default_headers = {
|
||||||
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
if api_key is not None:
|
||||||
|
default_headers[
|
||||||
|
"Authorization"
|
||||||
|
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
|
||||||
|
|
||||||
|
headers = {**headers, **default_headers}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def get_error_class(
|
||||||
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
|
) -> BaseLLMException:
|
||||||
|
return HuggingFaceError(
|
||||||
|
status_code=status_code, message=error_message, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_streamed_response_to_complete_response(
|
||||||
|
self,
|
||||||
|
response: httpx.Response,
|
||||||
|
logging_obj: LoggingClass,
|
||||||
|
model: str,
|
||||||
|
data: dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
streamed_response = CustomStreamWrapper(
|
||||||
|
completion_stream=response.iter_lines(),
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="huggingface",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
content = ""
|
||||||
|
for chunk in streamed_response:
|
||||||
|
content += chunk["choices"][0]["delta"]["content"]
|
||||||
|
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=data,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=completion_response,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
return completion_response
|
||||||
|
|
||||||
|
def convert_to_model_response_object( # noqa: PLR0915
|
||||||
|
self,
|
||||||
|
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
task: Optional[hf_tasks],
|
||||||
|
optional_params: dict,
|
||||||
|
encoding: Any,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
model: str,
|
||||||
|
):
|
||||||
|
if task is None:
|
||||||
|
task = "text-generation-inference" # default to tgi
|
||||||
|
|
||||||
|
if task == "conversational":
|
||||||
|
if len(completion_response["generated_text"]) > 0: # type: ignore
|
||||||
|
model_response.choices[0].message.content = completion_response[ # type: ignore
|
||||||
|
"generated_text"
|
||||||
|
]
|
||||||
|
elif task == "text-generation-inference":
|
||||||
|
if (
|
||||||
|
not isinstance(completion_response, list)
|
||||||
|
or not isinstance(completion_response[0], dict)
|
||||||
|
or "generated_text" not in completion_response[0]
|
||||||
|
):
|
||||||
|
raise HuggingFaceError(
|
||||||
|
status_code=422,
|
||||||
|
message=f"response is not in expected format - {completion_response}",
|
||||||
|
headers=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(completion_response[0]["generated_text"]) > 0:
|
||||||
|
model_response.choices[0].message.content = output_parser( # type: ignore
|
||||||
|
completion_response[0]["generated_text"]
|
||||||
|
)
|
||||||
|
## GETTING LOGPROBS + FINISH REASON
|
||||||
|
if (
|
||||||
|
"details" in completion_response[0]
|
||||||
|
and "tokens" in completion_response[0]["details"]
|
||||||
|
):
|
||||||
|
model_response.choices[0].finish_reason = completion_response[0][
|
||||||
|
"details"
|
||||||
|
]["finish_reason"]
|
||||||
|
sum_logprob = 0
|
||||||
|
for token in completion_response[0]["details"]["tokens"]:
|
||||||
|
if token["logprob"] is not None:
|
||||||
|
sum_logprob += token["logprob"]
|
||||||
|
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
|
||||||
|
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||||
|
if (
|
||||||
|
"details" in completion_response[0]
|
||||||
|
and "best_of_sequences" in completion_response[0]["details"]
|
||||||
|
):
|
||||||
|
choices_list = []
|
||||||
|
for idx, item in enumerate(
|
||||||
|
completion_response[0]["details"]["best_of_sequences"]
|
||||||
|
):
|
||||||
|
sum_logprob = 0
|
||||||
|
for token in item["tokens"]:
|
||||||
|
if token["logprob"] is not None:
|
||||||
|
sum_logprob += token["logprob"]
|
||||||
|
if len(item["generated_text"]) > 0:
|
||||||
|
message_obj = Message(
|
||||||
|
content=output_parser(item["generated_text"]),
|
||||||
|
logprobs=sum_logprob,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message_obj = Message(content=None)
|
||||||
|
choice_obj = Choices(
|
||||||
|
finish_reason=item["finish_reason"],
|
||||||
|
index=idx + 1,
|
||||||
|
message=message_obj,
|
||||||
|
)
|
||||||
|
choices_list.append(choice_obj)
|
||||||
|
model_response.choices.extend(choices_list)
|
||||||
|
elif task == "text-classification":
|
||||||
|
model_response.choices[0].message.content = json.dumps( # type: ignore
|
||||||
|
completion_response
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
isinstance(completion_response, list)
|
||||||
|
and len(completion_response[0]["generated_text"]) > 0
|
||||||
|
):
|
||||||
|
model_response.choices[0].message.content = output_parser( # type: ignore
|
||||||
|
completion_response[0]["generated_text"]
|
||||||
|
)
|
||||||
|
## CALCULATING USAGE
|
||||||
|
prompt_tokens = 0
|
||||||
|
try:
|
||||||
|
prompt_tokens = token_counter(model=model, messages=messages)
|
||||||
|
except Exception:
|
||||||
|
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||||
|
pass
|
||||||
|
output_text = model_response["choices"][0]["message"].get("content", "")
|
||||||
|
if output_text is not None and len(output_text) > 0:
|
||||||
|
completion_tokens = 0
|
||||||
|
try:
|
||||||
|
completion_tokens = len(
|
||||||
|
encoding.encode(
|
||||||
|
model_response["choices"][0]["message"].get("content", "")
|
||||||
|
)
|
||||||
|
) ##[TODO] use the llama2 tokenizer here
|
||||||
|
except Exception:
|
||||||
|
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
completion_tokens = 0
|
||||||
|
|
||||||
|
model_response.created = int(time.time())
|
||||||
|
model_response.model = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
model_response._hidden_params["original_response"] = completion_response
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def transform_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
raw_response: httpx.Response,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
logging_obj: LoggingClass,
|
||||||
|
request_data: Dict,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: Dict,
|
||||||
|
litellm_params: Dict,
|
||||||
|
encoding: Any,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
json_mode: Optional[bool] = None,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
||||||
|
task = litellm_params.get("task", None)
|
||||||
|
is_streamed = False
|
||||||
|
if (
|
||||||
|
raw_response.__dict__["headers"].get("Content-Type", "")
|
||||||
|
== "text/event-stream"
|
||||||
|
):
|
||||||
|
is_streamed = True
|
||||||
|
|
||||||
|
# iterate over the complete streamed response, and return the final answer
|
||||||
|
if is_streamed:
|
||||||
|
completion_response = self._convert_streamed_response_to_complete_response(
|
||||||
|
response=raw_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
model=model,
|
||||||
|
data=request_data,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=request_data,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=raw_response.text,
|
||||||
|
additional_args={"complete_input_dict": request_data},
|
||||||
|
)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = raw_response.json()
|
||||||
|
if isinstance(completion_response, dict):
|
||||||
|
completion_response = [completion_response]
|
||||||
|
except Exception:
|
||||||
|
raise HuggingFaceError(
|
||||||
|
message=f"Original Response received: {raw_response.text}",
|
||||||
|
status_code=raw_response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(completion_response, dict) and "error" in completion_response:
|
||||||
|
raise HuggingFaceError(
|
||||||
|
message=completion_response["error"], # type: ignore
|
||||||
|
status_code=raw_response.status_code,
|
||||||
|
)
|
||||||
|
return self.convert_to_model_response_object(
|
||||||
|
completion_response=completion_response,
|
||||||
|
model_response=model_response,
|
||||||
|
task=task if task is not None and task in hf_task_list else None,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
)
|
|
@ -34,7 +34,7 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
|
||||||
return api_base, dynamic_api_key
|
return api_base, dynamic_api_key
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_tool_response_to_message(
|
def _json_mode_convert_tool_response_to_message(
|
||||||
message: ChatCompletionAssistantMessage, json_mode: bool
|
message: ChatCompletionAssistantMessage, json_mode: bool
|
||||||
) -> ChatCompletionAssistantMessage:
|
) -> ChatCompletionAssistantMessage:
|
||||||
"""
|
"""
|
||||||
|
@ -88,8 +88,10 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
|
||||||
|
|
||||||
if json_mode:
|
if json_mode:
|
||||||
for choice in response_json["choices"]:
|
for choice in response_json["choices"]:
|
||||||
message = OpenAILikeChatConfig._convert_tool_response_to_message(
|
message = (
|
||||||
choice.get("message"), json_mode
|
OpenAILikeChatConfig._json_mode_convert_tool_response_to_message(
|
||||||
|
choice.get("message"), json_mode
|
||||||
|
)
|
||||||
)
|
)
|
||||||
choice["message"] = message
|
choice["message"] = message
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,13 @@ Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
|
||||||
Docs: https://openrouter.ai/docs/parameters
|
Docs: https://openrouter.ai/docs/parameters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, AsyncIterator, Iterator, Optional, Union
|
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
from litellm.types.llms.openrouter import OpenRouterErrorMessage
|
from litellm.types.llms.openrouter import OpenRouterErrorMessage
|
||||||
from litellm.types.utils import ModelResponse, ModelResponseStream
|
from litellm.types.utils import ModelResponse, ModelResponseStream
|
||||||
|
|
||||||
|
@ -47,6 +48,27 @@ class OpenrouterConfig(OpenAIGPTConfig):
|
||||||
] = extra_body # openai client supports `extra_body` param
|
] = extra_body # openai client supports `extra_body` param
|
||||||
return mapped_openai_params
|
return mapped_openai_params
|
||||||
|
|
||||||
|
def transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Transform the overall request to be sent to the API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The transformed request. Sent as the body of the API call.
|
||||||
|
"""
|
||||||
|
extra_body = optional_params.pop("extra_body", {})
|
||||||
|
response = super().transform_request(
|
||||||
|
model, messages, optional_params, litellm_params, headers
|
||||||
|
)
|
||||||
|
response.update(extra_body)
|
||||||
|
return response
|
||||||
|
|
||||||
def get_error_class(
|
def get_error_class(
|
||||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
) -> BaseLLMException:
|
) -> BaseLLMException:
|
||||||
|
|
|
@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
from httpx import Headers, Response
|
from httpx import Headers, Response
|
||||||
|
|
||||||
|
from litellm.constants import DEFAULT_MAX_TOKENS
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
from litellm.types.utils import ModelResponse
|
from litellm.types.utils import ModelResponse
|
||||||
|
@ -27,7 +28,7 @@ class PredibaseConfig(BaseConfig):
|
||||||
decoder_input_details: Optional[bool] = None
|
decoder_input_details: Optional[bool] = None
|
||||||
details: bool = True # enables returning logprobs + best of
|
details: bool = True # enables returning logprobs + best of
|
||||||
max_new_tokens: int = (
|
max_new_tokens: int = (
|
||||||
256 # openai default - requests hang if max_new_tokens not given
|
DEFAULT_MAX_TOKENS # openai default - requests hang if max_new_tokens not given
|
||||||
)
|
)
|
||||||
repetition_penalty: Optional[float] = None
|
repetition_penalty: Optional[float] = None
|
||||||
return_full_text: Optional[
|
return_full_text: Optional[
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue