Merge branch 'BerriAI:main' into main

This commit is contained in:
Hannes Burrichter 2024-05-14 13:31:07 +02:00 committed by GitHub
commit 1bd6a1ba05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
142 changed files with 6672 additions and 1270 deletions

10
.git-blame-ignore-revs Normal file
View file

@ -0,0 +1,10 @@
# Add the commit hash of any commit you want to ignore in `git blame` here.
# One commit hash per line.
#
# The GitHub Blame UI will use this file automatically!
#
# Run this command to always ignore formatting commits in `git blame`
# git config blame.ignoreRevsFile .git-blame-ignore-revs
# Update pydantic code to fix warnings (GH-3600)
876840e9957bc7e9f7d6a2b58c4d7c53dad16481

View file

@ -1,6 +1,3 @@
<!-- This is just examples. You can remove all items if you want. -->
<!-- Please remove all comments. -->
## Title ## Title
<!-- e.g. "Implement user authentication feature" --> <!-- e.g. "Implement user authentication feature" -->
@ -18,7 +15,6 @@
🐛 Bug Fix 🐛 Bug Fix
🧹 Refactoring 🧹 Refactoring
📖 Documentation 📖 Documentation
💻 Development Environment
🚄 Infrastructure 🚄 Infrastructure
✅ Test ✅ Test
@ -26,22 +22,8 @@
<!-- List of changes --> <!-- List of changes -->
## Testing ## [REQUIRED] Testing - Attach a screenshot of any new tests passing locall
If UI changes, send a screenshot/GIF of working UI fixes
<!-- Test procedure --> <!-- Test procedure -->
## Notes
<!-- Test results -->
<!-- Points to note for the reviewer, consultation content, concerns -->
## Pre-Submission Checklist (optional but appreciated):
- [ ] I have included relevant documentation updates (stored in /docs/my-website)
## OS Tests (optional but appreciated):
- [ ] Tested on Windows
- [ ] Tested on MacOS
- [ ] Tested on Linux

187
cookbook/liteLLM_clarifai_Demo.ipynb vendored Normal file
View file

@ -0,0 +1,187 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LiteLLM Clarifai \n",
"This notebook walks you through on how to use liteLLM integration of Clarifai and call LLM model from clarifai with response in openAI output format."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-Requisites"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#install necessary packages\n",
"!pip install litellm\n",
"!pip install clarifai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To obtain Clarifai Personal Access Token follow the steps mentioned in the [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"## Set Clarifai Credentials\n",
"import os\n",
"os.environ[\"CLARIFAI_API_KEY\"]= \"YOUR_CLARIFAI_PAT\" # Clarifai PAT"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mistral-large"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import litellm\n",
"\n",
"litellm.set_verbose=False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mistral large response : ModelResponse(id='chatcmpl-6eed494d-7ae2-4870-b9c2-6a64d50a6151', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\"In the grand tapestry of time, where tales unfold,\\nLies the chronicle of ages, a sight to behold.\\nA tale of empires rising, and kings of old,\\nOf civilizations lost, and stories untold.\\n\\nOnce upon a yesterday, in a time so vast,\\nHumans took their first steps, casting shadows in the past.\\nFrom the cradle of mankind, a journey they embarked,\\nThrough stone and bronze and iron, their skills they sharpened and marked.\\n\\nEgyptians built pyramids, reaching for the skies,\\nWhile Greeks sought wisdom, truth, in philosophies that lie.\\nRoman legions marched, their empire to expand,\\nAnd in the East, the Silk Road joined the world, hand in hand.\\n\\nThe Middle Ages came, with knights in shining armor,\\nFeudal lords and serfs, a time of both clamor and calm order.\\nThen Renaissance bloomed, like a flower in the sun,\\nA rebirth of art and science, a new age had begun.\\n\\nAcross the vast oceans, explorers sailed with courage bold,\\nDiscovering new lands, stories of adventure, untold.\\nIndustrial Revolution churned, progress in its wake,\\nMachines and factories, a whole new world to make.\\n\\nTwo World Wars raged, a testament to man's strife,\\nYet from the ashes rose hope, a renewed will for life.\\nInto the modern era, technology took flight,\\nConnecting every corner, bathed in digital light.\\n\\nHistory, a symphony, a melody of time,\\nA testament to human will, resilience so sublime.\\nIn every page, a lesson, in every tale, a guide,\\nFor understanding our past, shapes our future's tide.\", role='assistant'))], created=1713896412, model='https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=13, completion_tokens=338, total_tokens=351))\n"
]
}
],
"source": [
"from litellm import completion\n",
"\n",
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
"response=completion(\n",
" model=\"clarifai/mistralai.completion.mistral-large\",\n",
" messages=messages,\n",
" )\n",
"\n",
"print(f\"Mistral large response : {response}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Claude-2.1 "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Claude-2.1 response : ModelResponse(id='chatcmpl-d126c919-4db4-4aa3-ac8f-7edea41e0b93', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\" Here's a poem I wrote about history:\\n\\nThe Tides of Time\\n\\nThe tides of time ebb and flow,\\nCarrying stories of long ago.\\nFigures and events come into light,\\nShaping the future with all their might.\\n\\nKingdoms rise, empires fall, \\nLeaving traces that echo down every hall.\\nRevolutions bring change with a fiery glow,\\nToppling structures from long ago.\\n\\nExplorers traverse each ocean and land,\\nSeeking treasures they don't understand.\\nWhile artists and writers try to make their mark,\\nHoping their works shine bright in the dark.\\n\\nThe cycle repeats again and again,\\nAs humanity struggles to learn from its pain.\\nThough the players may change on history's stage,\\nThe themes stay the same from age to age.\\n\\nWar and peace, life and death,\\nLove and strife with every breath.\\nThe tides of time continue their dance,\\nAs we join in, by luck or by chance.\\n\\nSo we study the past to light the way forward, \\nHeeding warnings from stories told and heard.\\nThe future unfolds from this unending flow -\\nWhere the tides of time ultimately go.\", role='assistant'))], created=1713896579, model='https://api.clarifai.com/v2/users/anthropic/apps/completion/models/claude-2_1/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=12, completion_tokens=232, total_tokens=244))\n"
]
}
],
"source": [
"from litellm import completion\n",
"\n",
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
"response=completion(\n",
" model=\"clarifai/anthropic.completion.claude-2_1\",\n",
" messages=messages,\n",
" )\n",
"\n",
"print(f\"Claude-2.1 response : {response}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### OpenAI GPT-4 (Streaming)\n",
"Though clarifai doesn't support streaming, still you can call stream and get the response in standard StreamResponse format of liteLLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"In the quiet corners of time's grand hall,\\nLies the tale of rise and fall.\\nFrom ancient ruins to modern sprawl,\\nHistory, the greatest story of them all.\\n\\nEmpires have risen, empires have decayed,\\nThrough the eons, memories have stayed.\\nIn the book of time, history is laid,\\nA tapestry of events, meticulously displayed.\\n\\nThe pyramids of Egypt, standing tall,\\nThe Roman Empire's mighty sprawl.\\nFrom Alexander's conquest, to the Berlin Wall,\\nHistory, a silent witness to it all.\\n\\nIn the shadow of the past we tread,\\nWhere once kings and prophets led.\\nTheir stories in our hearts are spread,\\nEchoes of their words, in our minds are read.\\n\\nBattles fought and victories won,\\nActs of courage under the sun.\\nTales of love, of deeds done,\\nIn history's grand book, they all run.\\n\\nHeroes born, legends made,\\nIn the annals of time, they'll never fade.\\nTheir triumphs and failures all displayed,\\nIn the eternal march of history's parade.\\n\\nThe ink of the past is forever dry,\\nBut its lessons, we cannot deny.\\nIn its stories, truths lie,\\nIn its wisdom, we rely.\\n\\nHistory, a mirror to our past,\\nA guide for the future vast.\\nThrough its lens, we're ever cast,\\nIn the drama of life, forever vast.\", role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n"
]
}
],
"source": [
"from litellm import completion\n",
"\n",
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
"response = completion(\n",
" model=\"clarifai/openai.chat-completion.GPT-4\",\n",
" messages=messages,\n",
" stream=True,\n",
" api_key = \"c75cc032415e45368be331fdd2c06db0\")\n",
"\n",
"for chunk in response:\n",
" print(chunk)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -4,6 +4,12 @@ LiteLLM allows you to:
* Send 1 completion call to many models: Return Fastest Response * Send 1 completion call to many models: Return Fastest Response
* Send 1 completion call to many models: Return All Responses * Send 1 completion call to many models: Return All Responses
:::info
Trying to do batch completion on LiteLLM Proxy ? Go here: https://docs.litellm.ai/docs/proxy/user_keys#beta-batch-completions---pass-model-as-list
:::
## Send multiple completion calls to 1 model ## Send multiple completion calls to 1 model
In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call. In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.

View file

@ -37,11 +37,12 @@ print(response) # ["max_tokens", "tools", "tool_choice", "stream"]
This is a list of openai params we translate across providers. This is a list of openai params we translate across providers.
This list is constantly being updated. Use `litellm.get_supported_openai_params()` for an updated list of params for each model + provider
| Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers | | Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--| |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--|
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | |Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ | ✅ |
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | ✅ | ✅ | ✅ | ✅ |
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ | |OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ | |Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | |Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |

View file

@ -106,11 +106,12 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
## Custom mapping list ## Custom mapping list
Base case - we return the original exception. Base case - we return `litellm.APIConnectionError` exception (inherits from openai's APIConnectionError exception).
| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError | | custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------| |----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| watsonx | | | | | | | |✓| | | |
| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |

View file

@ -137,6 +137,7 @@ response = completion(
"existing_trace_id": "trace-id22", "existing_trace_id": "trace-id22",
"trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata "trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
"update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value "update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
"debug_langfuse": True, # Will log the exact metadata sent to litellm for the trace/generation as `metadata_passed_to_litellm`
}, },
) )
@ -214,8 +215,20 @@ chat(messages)
## Redacting Messages, Response Content from Langfuse Logging ## Redacting Messages, Response Content from Langfuse Logging
### Redact Messages and Responses from all Langfuse Logging
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged. Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged.
### Redact Messages and Responses from specific Langfuse Logging
In the metadata typically passed for text completion or embedding calls you can set specific keys to mask the messages and responses for this call.
Setting `mask_input` to `True` will mask the input from being logged for this call
Setting `mask_output` to `True` will make the output from being logged for this call.
Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message.
## Troubleshooting & Errors ## Troubleshooting & Errors
### Data not getting logged to Langfuse ? ### Data not getting logged to Langfuse ?
- Ensure you're on the latest version of langfuse `pip install langfuse -U`. The latest version allows litellm to log JSON input/outputs to langfuse - Ensure you're on the latest version of langfuse `pip install langfuse -U`. The latest version allows litellm to log JSON input/outputs to langfuse

View file

@ -0,0 +1,177 @@
# Clarifai
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.
## Pre-Requisites
`pip install clarifai`
`pip install litellm`
## Required Environment Variables
To obtain your Clarifai Personal access token follow this [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/). Optionally the PAT can also be passed in `completion` function.
```python
os.environ["CALRIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT
```
## Usage
```python
import os
from litellm import completion
os.environ["CLARIFAI_API_KEY"] = ""
response = completion(
model="clarifai/mistralai.completion.mistral-large",
messages=[{ "content": "Tell me a joke about physics?","role": "user"}]
)
```
**Output**
```json
{
"id": "chatcmpl-572701ee-9ab2-411c-ac75-46c1ba18e781",
"choices": [
{
"finish_reason": "stop",
"index": 1,
"message": {
"content": "Sure, here's a physics joke for you:\n\nWhy can't you trust an atom?\n\nBecause they make up everything!",
"role": "assistant"
}
}
],
"created": 1714410197,
"model": "https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs",
"object": "chat.completion",
"system_fingerprint": null,
"usage": {
"prompt_tokens": 14,
"completion_tokens": 24,
"total_tokens": 38
}
}
```
## Clarifai models
liteLLM supports non-streaming requests to all models on [Clarifai community](https://clarifai.com/explore/models?filterData=%5B%7B%22field%22%3A%22use_cases%22%2C%22value%22%3A%5B%22llm%22%5D%7D%5D&page=1&perPage=24)
Example Usage - Note: liteLLM supports all models deployed on Clarifai
## Llama LLMs
| Model Name | Function Call |
---------------------------|---------------------------------|
| clarifai/meta.Llama-2.llama2-7b-chat | `completion('clarifai/meta.Llama-2.llama2-7b-chat', messages)`
| clarifai/meta.Llama-2.llama2-13b-chat | `completion('clarifai/meta.Llama-2.llama2-13b-chat', messages)`
| clarifai/meta.Llama-2.llama2-70b-chat | `completion('clarifai/meta.Llama-2.llama2-70b-chat', messages)` |
| clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`|
| clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` |
## Mistal LLMs
| Model Name | Function Call |
|---------------------------------------------|------------------------------------------------------------------------|
| clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` |
| clarifai/mistralai.completion.mistral-large | `completion('clarifai/mistralai.completion.mistral-large', messages)` |
| clarifai/mistralai.completion.mistral-medium | `completion('clarifai/mistralai.completion.mistral-medium', messages)` |
| clarifai/mistralai.completion.mistral-small | `completion('clarifai/mistralai.completion.mistral-small', messages)` |
| clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1 | `completion('clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1', messages)`
| clarifai/mistralai.completion.mistral-7B-OpenOrca | `completion('clarifai/mistralai.completion.mistral-7B-OpenOrca', messages)` |
| clarifai/mistralai.completion.openHermes-2-mistral-7B | `completion('clarifai/mistralai.completion.openHermes-2-mistral-7B', messages)` |
## Jurassic LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/ai21.complete.Jurassic2-Grande | `completion('clarifai/ai21.complete.Jurassic2-Grande', messages)` |
| clarifai/ai21.complete.Jurassic2-Grande-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Grande-Instruct', messages)` |
| clarifai/ai21.complete.Jurassic2-Jumbo-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Jumbo-Instruct', messages)` |
| clarifai/ai21.complete.Jurassic2-Jumbo | `completion('clarifai/ai21.complete.Jurassic2-Jumbo', messages)` |
| clarifai/ai21.complete.Jurassic2-Large | `completion('clarifai/ai21.complete.Jurassic2-Large', messages)` |
## Wizard LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/wizardlm.generate.wizardCoder-Python-34B | `completion('clarifai/wizardlm.generate.wizardCoder-Python-34B', messages)` |
| clarifai/wizardlm.generate.wizardLM-70B | `completion('clarifai/wizardlm.generate.wizardLM-70B', messages)` |
| clarifai/wizardlm.generate.wizardLM-13B | `completion('clarifai/wizardlm.generate.wizardLM-13B', messages)` |
| clarifai/wizardlm.generate.wizardCoder-15B | `completion('clarifai/wizardlm.generate.wizardCoder-15B', messages)` |
## Anthropic models
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/anthropic.completion.claude-v1 | `completion('clarifai/anthropic.completion.claude-v1', messages)` |
| clarifai/anthropic.completion.claude-instant-1_2 | `completion('clarifai/anthropic.completion.claude-instant-1_2', messages)` |
| clarifai/anthropic.completion.claude-instant | `completion('clarifai/anthropic.completion.claude-instant', messages)` |
| clarifai/anthropic.completion.claude-v2 | `completion('clarifai/anthropic.completion.claude-v2', messages)` |
| clarifai/anthropic.completion.claude-2_1 | `completion('clarifai/anthropic.completion.claude-2_1', messages)` |
| clarifai/anthropic.completion.claude-3-opus | `completion('clarifai/anthropic.completion.claude-3-opus', messages)` |
| clarifai/anthropic.completion.claude-3-sonnet | `completion('clarifai/anthropic.completion.claude-3-sonnet', messages)` |
## OpenAI GPT LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/openai.chat-completion.GPT-4 | `completion('clarifai/openai.chat-completion.GPT-4', messages)` |
| clarifai/openai.chat-completion.GPT-3_5-turbo | `completion('clarifai/openai.chat-completion.GPT-3_5-turbo', messages)` |
| clarifai/openai.chat-completion.gpt-4-turbo | `completion('clarifai/openai.chat-completion.gpt-4-turbo', messages)` |
| clarifai/openai.completion.gpt-3_5-turbo-instruct | `completion('clarifai/openai.completion.gpt-3_5-turbo-instruct', messages)` |
## GCP LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/gcp.generate.gemini-1_5-pro | `completion('clarifai/gcp.generate.gemini-1_5-pro', messages)` |
| clarifai/gcp.generate.imagen-2 | `completion('clarifai/gcp.generate.imagen-2', messages)` |
| clarifai/gcp.generate.code-gecko | `completion('clarifai/gcp.generate.code-gecko', messages)` |
| clarifai/gcp.generate.code-bison | `completion('clarifai/gcp.generate.code-bison', messages)` |
| clarifai/gcp.generate.text-bison | `completion('clarifai/gcp.generate.text-bison', messages)` |
| clarifai/gcp.generate.gemma-2b-it | `completion('clarifai/gcp.generate.gemma-2b-it', messages)` |
| clarifai/gcp.generate.gemma-7b-it | `completion('clarifai/gcp.generate.gemma-7b-it', messages)` |
| clarifai/gcp.generate.gemini-pro | `completion('clarifai/gcp.generate.gemini-pro', messages)` |
| clarifai/gcp.generate.gemma-1_1-7b-it | `completion('clarifai/gcp.generate.gemma-1_1-7b-it', messages)` |
## Cohere LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/cohere.generate.cohere-generate-command | `completion('clarifai/cohere.generate.cohere-generate-command', messages)` |
clarifai/cohere.generate.command-r-plus' | `completion('clarifai/clarifai/cohere.generate.command-r-plus', messages)`|
## Databricks LLMs
| Model Name | Function Call |
|---------------------------------------------------|---------------------------------------------------------------------|
| clarifai/databricks.drbx.dbrx-instruct | `completion('clarifai/databricks.drbx.dbrx-instruct', messages)` |
| clarifai/databricks.Dolly-v2.dolly-v2-12b | `completion('clarifai/databricks.Dolly-v2.dolly-v2-12b', messages)`|
## Microsoft LLMs
| Model Name | Function Call |
|---------------------------------------------------|---------------------------------------------------------------------|
| clarifai/microsoft.text-generation.phi-2 | `completion('clarifai/microsoft.text-generation.phi-2', messages)` |
| clarifai/microsoft.text-generation.phi-1_5 | `completion('clarifai/microsoft.text-generation.phi-1_5', messages)`|
## Salesforce models
| Model Name | Function Call |
|-----------------------------------------------------------|-------------------------------------------------------------------------------|
| clarifai/salesforce.blip.general-english-image-caption-blip-2 | `completion('clarifai/salesforce.blip.general-english-image-caption-blip-2', messages)` |
| clarifai/salesforce.xgen.xgen-7b-8k-instruct | `completion('clarifai/salesforce.xgen.xgen-7b-8k-instruct', messages)` |
## Other Top performing LLMs
| Model Name | Function Call |
|---------------------------------------------------|---------------------------------------------------------------------|
| clarifai/deci.decilm.deciLM-7B-instruct | `completion('clarifai/deci.decilm.deciLM-7B-instruct', messages)` |
| clarifai/upstage.solar.solar-10_7b-instruct | `completion('clarifai/upstage.solar.solar-10_7b-instruct', messages)` |
| clarifai/openchat.openchat.openchat-3_5-1210 | `completion('clarifai/openchat.openchat.openchat-3_5-1210', messages)` |
| clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B | `completion('clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B', messages)` |
| clarifai/fblgit.una-cybertron.una-cybertron-7b-v2 | `completion('clarifai/fblgit.una-cybertron.una-cybertron-7b-v2', messages)` |
| clarifai/tiiuae.falcon.falcon-40b-instruct | `completion('clarifai/tiiuae.falcon.falcon-40b-instruct', messages)` |
| clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat | `completion('clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat', messages)` |
| clarifai/bigcode.code.StarCoder | `completion('clarifai/bigcode.code.StarCoder', messages)` |
| clarifai/mosaicml.mpt.mpt-7b-instruct | `completion('clarifai/mosaicml.mpt.mpt-7b-instruct', messages)` |

View file

@ -20,7 +20,7 @@ os.environ["OPENAI_API_KEY"] = "your-api-key"
# openai call # openai call
response = completion( response = completion(
model = "gpt-3.5-turbo", model = "gpt-4o",
messages=[{ "content": "Hello, how are you?","role": "user"}] messages=[{ "content": "Hello, how are you?","role": "user"}]
) )
``` ```
@ -163,6 +163,8 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
| Model Name | Function Call | | Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------| |-----------------------|-----------------------------------------------------------------|
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
| gpt-4o-2024-05-13 | `response = completion(model="gpt-4o-2024-05-13", messages=messages)` |
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` | | gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` | | gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` | | gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |

View file

@ -1,14 +1,22 @@
# 🚨 Alerting # 🚨 Alerting
Get alerts for: Get alerts for:
- Hanging LLM api calls - Hanging LLM api calls
- Failed LLM api calls - Failed LLM api calls
- Slow LLM api calls - Slow LLM api calls
- Budget Tracking per key/user: - Budget Tracking per key/user:
- When a User/Key crosses their Budget - When a User/Key crosses their Budget
- When a User/Key is 15% away from crossing their Budget - When a User/Key is 15% away from crossing their Budget
- Spend Reports - Weekly & Monthly spend per Team, Tag
- Failed db read/writes - Failed db read/writes
As a bonus, you can also get "daily reports" posted to your slack channel.
These reports contain key metrics like:
- Top 5 deployments with most failed requests
- Top 5 slowest deployments
## Quick Start ## Quick Start
Set up a slack alert channel to receive alerts from proxy. Set up a slack alert channel to receive alerts from proxy.
@ -20,7 +28,8 @@ Get a slack webhook url from https://api.slack.com/messaging/webhooks
### Step 2: Update config.yaml ### Step 2: Update config.yaml
Let's save a bad key to our proxy - Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts.
- Just for testing purposes, let's save a bad key to our proxy.
```yaml ```yaml
model_list: model_list:
@ -33,13 +42,11 @@ general_settings:
alerting: ["slack"] alerting: ["slack"]
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+ alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
environment_variables:
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
SLACK_DAILY_REPORT_FREQUENCY: "86400" # 24 hours; Optional: defaults to 12 hours
``` ```
Set `SLACK_WEBHOOK_URL` in your proxy env
```shell
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
```
### Step 3: Start proxy ### Step 3: Start proxy

View file

@ -1,8 +1,136 @@
# Cost Tracking - Azure import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 💸 Spend Tracking
Track spend for keys, users, and teams across 100+ LLMs.
## Getting Spend Reports - To Charge Other Teams, API Keys
Use the `/global/spend/report` endpoint to get daily spend per team, with a breakdown of spend per API Key, Model
### Example Request
```shell
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2023-04-01&end_date=2024-06-30' \
-H 'Authorization: Bearer sk-1234'
```
### Example Response
<Tabs>
<TabItem value="response" label="Expected Response">
```shell
[
{
"group_by_day": "2024-04-30T00:00:00+00:00",
"teams": [
{
"team_name": "Prod Team",
"total_spend": 0.0015265,
"metadata": [ # see the spend by unique(key + model)
{
"model": "gpt-4",
"spend": 0.00123,
"total_tokens": 28,
"api_key": "88dc28.." # the hashed api key
},
{
"model": "gpt-4",
"spend": 0.00123,
"total_tokens": 28,
"api_key": "a73dc2.." # the hashed api key
},
{
"model": "chatgpt-v-2",
"spend": 0.000214,
"total_tokens": 122,
"api_key": "898c28.." # the hashed api key
},
{
"model": "gpt-3.5-turbo",
"spend": 0.0000825,
"total_tokens": 85,
"api_key": "84dc28.." # the hashed api key
}
]
}
]
}
]
```
</TabItem>
<TabItem value="py-script" label="Script to Parse Response (Python)">
```python
import requests
url = 'http://localhost:4000/global/spend/report'
params = {
'start_date': '2023-04-01',
'end_date': '2024-06-30'
}
headers = {
'Authorization': 'Bearer sk-1234'
}
# Make the GET request
response = requests.get(url, headers=headers, params=params)
spend_report = response.json()
for row in spend_report:
date = row["group_by_day"]
teams = row["teams"]
for team in teams:
team_name = team["team_name"]
total_spend = team["total_spend"]
metadata = team["metadata"]
print(f"Date: {date}")
print(f"Team: {team_name}")
print(f"Total Spend: {total_spend}")
print("Metadata: ", metadata)
print()
```
Output from script
```shell
# Date: 2024-05-11T00:00:00+00:00
# Team: local_test_team
# Total Spend: 0.003675099999999999
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.003675099999999999, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 3105}]
# Date: 2024-05-13T00:00:00+00:00
# Team: Unassigned Team
# Total Spend: 3.4e-05
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 3.4e-05, 'api_key': '9569d13c9777dba68096dea49b0b03e0aaf4d2b65d4030eda9e8a2733c3cd6e0', 'total_tokens': 50}]
# Date: 2024-05-13T00:00:00+00:00
# Team: central
# Total Spend: 0.000684
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.000684, 'api_key': '0323facdf3af551594017b9ef162434a9b9a8ca1bbd9ccbd9d6ce173b1015605', 'total_tokens': 498}]
# Date: 2024-05-13T00:00:00+00:00
# Team: local_test_team
# Total Spend: 0.0005715000000000001
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.0005715000000000001, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 423}]
```
</TabItem>
</Tabs>
## Spend Tracking for Azure
Set base model for cost tracking azure image-gen call Set base model for cost tracking azure image-gen call
## Image Generation ### Image Generation
```yaml ```yaml
model_list: model_list:
@ -17,7 +145,7 @@ model_list:
mode: image_generation mode: image_generation
``` ```
## Chat Completions / Embeddings ### Chat Completions / Embeddings
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking **Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking

View file

@ -3,7 +3,7 @@ import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina # 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
- [Logging to Sentry](#logging-proxy-inputoutput---sentry) - [Logging to Sentry](#logging-proxy-inputoutput---sentry)
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry) - [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
- [Logging to Athina](#logging-proxy-inputoutput-athina) - [Logging to Athina](#logging-proxy-inputoutput-athina)
- [(BETA) Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
## Custom Callback Class [Async] ## Custom Callback Class [Async]
Use this when you want to run custom callbacks in `python` Use this when you want to run custom callbacks in `python`
@ -1037,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
] ]
}' }'
``` ```
## (BETA) Moderation with Azure Content Safety
[Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content moderation APIs to detect potential offensive, harmful, or risky content in text.
We will use the `--config` to set `litellm.success_callback = ["azure_content_safety"]` this will moderate all LLM calls using Azure Content Safety.
**Step 0** Deploy Azure Content Safety
Deploy an Azure Content-Safety instance from the Azure Portal and get the `endpoint` and `key`.
**Step 1** Set Athina API key
```shell
AZURE_CONTENT_SAFETY_KEY = "<your-azure-content-safety-key>"
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
callbacks: ["azure_content_safety"]
azure_content_safety_params:
endpoint: "<your-azure-content-safety-endpoint>"
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hi, how are you?"
}
]
}'
```
An HTTP 400 error will be returned if the content is detected with a value greater than the threshold set in the `config.yaml`.
The details of the response will describe :
- The `source` : input text or llm generated text
- The `category` : the category of the content that triggered the moderation
- The `severity` : the severity from 0 to 10
**Step 4**: Customizing Azure Content Safety Thresholds
You can customize the thresholds for each category by setting the `thresholds` in the `config.yaml`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
callbacks: ["azure_content_safety"]
azure_content_safety_params:
endpoint: "<your-azure-content-safety-endpoint>"
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
thresholds:
Hate: 6
SelfHarm: 8
Sexual: 6
Violence: 4
```
:::info
`thresholds` are not required by default, but you can tune the values to your needs.
Default values is `4` for all categories
:::

View file

@ -151,7 +151,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
}' }'
``` ```
## Advanced - Context Window Fallbacks ## Advanced - Context Window Fallbacks (Pre-Call Checks + Fallbacks)
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**. **Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
@ -232,16 +232,16 @@ model_list:
- model_name: gpt-3.5-turbo-small - model_name: gpt-3.5-turbo-small
litellm_params: litellm_params:
model: azure/chatgpt-v-2 model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview" api_version: "2023-07-01-preview"
model_info: model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo-large - model_name: gpt-3.5-turbo-large
litellm_params: litellm_params:
model: gpt-3.5-turbo-1106 model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
- model_name: claude-opus - model_name: claude-opus
litellm_params: litellm_params:
@ -287,6 +287,69 @@ print(response)
</Tabs> </Tabs>
## Advanced - EU-Region Filtering (Pre-Call Checks)
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
Set 'region_name' of deployment.
**Note:** LiteLLM can automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
**1. Set Config**
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
region_name: "eu" # 👈 SET EU-REGION
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
- model_name: gemini-pro
litellm_params:
model: vertex_ai/gemini-pro-1.5
vertex_project: adroit-crow-1234
vertex_location: us-east1 # 👈 AUTOMATICALLY INFERS 'region_name'
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.with_raw_response.create(
model="gpt-3.5-turbo",
messages = [{"role": "user", "content": "Who was Alexander?"}]
)
print(response)
print(f"response.headers.get('x-litellm-model-api-base')")
```
## Advanced - Custom Timeouts, Stream Timeouts - Per Model ## Advanced - Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params` For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml ```yaml

View file

@ -110,7 +110,7 @@ general_settings:
admin_jwt_scope: "litellm-proxy-admin" admin_jwt_scope: "litellm-proxy-admin"
``` ```
## Advanced - Spend Tracking (User / Team / Org) ## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
Set the field in the jwt token, which corresponds to a litellm user / team / org. Set the field in the jwt token, which corresponds to a litellm user / team / org.
@ -123,6 +123,7 @@ general_settings:
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
end_user_id_jwt_field: "customer_id" # 👈 CAN BE ANY FIELD
``` ```
Expected JWT: Expected JWT:
@ -131,7 +132,7 @@ Expected JWT:
{ {
"client_id": "my-unique-team", "client_id": "my-unique-team",
"sub": "my-unique-user", "sub": "my-unique-user",
"org_id": "my-unique-org" "org_id": "my-unique-org",
} }
``` ```

View file

@ -365,6 +365,188 @@ curl --location 'http://0.0.0.0:4000/moderations' \
## Advanced ## Advanced
### (BETA) Batch Completions - pass multiple models
Use this when you want to send 1 request to N Models
#### Expected Request Format
Pass model as a string of comma separated value of models. Example `"model"="llama3,gpt-3.5-turbo"`
This same request will be sent to the following model groups on the [litellm proxy config.yaml](https://docs.litellm.ai/docs/proxy/configs)
- `model_name="llama3"`
- `model_name="gpt-3.5-turbo"`
<Tabs>
<TabItem value="openai-py" label="OpenAI Python SDK">
```python
import openai
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = client.chat.completions.create(
model="gpt-3.5-turbo,llama3",
messages=[
{"role": "user", "content": "this is a test request, write a short poem"}
],
)
print(response)
```
#### Expected Response Format
Get a list of responses when `model` is passed as a list
```python
[
ChatCompletion(
id='chatcmpl-9NoYhS2G0fswot0b6QpoQgmRQMaIf',
choices=[
Choice(
finish_reason='stop',
index=0,
logprobs=None,
message=ChatCompletionMessage(
content='In the depths of my soul, a spark ignites\nA light that shines so pure and bright\nIt dances and leaps, refusing to die\nA flame of hope that reaches the sky\n\nIt warms my heart and fills me with bliss\nA reminder that in darkness, there is light to kiss\nSo I hold onto this fire, this guiding light\nAnd let it lead me through the darkest night.',
role='assistant',
function_call=None,
tool_calls=None
)
)
],
created=1715462919,
model='gpt-3.5-turbo-0125',
object='chat.completion',
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=83,
prompt_tokens=17,
total_tokens=100
)
),
ChatCompletion(
id='chatcmpl-4ac3e982-da4e-486d-bddb-ed1d5cb9c03c',
choices=[
Choice(
finish_reason='stop',
index=0,
logprobs=None,
message=ChatCompletionMessage(
content="A test request, and I'm delighted!\nHere's a short poem, just for you:\n\nMoonbeams dance upon the sea,\nA path of light, for you to see.\nThe stars up high, a twinkling show,\nA night of wonder, for all to know.\n\nThe world is quiet, save the night,\nA peaceful hush, a gentle light.\nThe world is full, of beauty rare,\nA treasure trove, beyond compare.\n\nI hope you enjoyed this little test,\nA poem born, of whimsy and jest.\nLet me know, if there's anything else!",
role='assistant',
function_call=None,
tool_calls=None
)
)
],
created=1715462919,
model='groq/llama3-8b-8192',
object='chat.completion',
system_fingerprint='fp_a2c8d063cb',
usage=CompletionUsage(
completion_tokens=120,
prompt_tokens=20,
total_tokens=140
)
)
]
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama3,gpt-3.5-turbo",
"max_tokens": 10,
"user": "litellm2",
"messages": [
{
"role": "user",
"content": "is litellm getting better"
}
]
}'
```
#### Expected Response Format
Get a list of responses when `model` is passed as a list
```json
[
{
"id": "chatcmpl-3dbd5dd8-7c82-4ca3-bf1f-7c26f497cf2b",
"choices": [
{
"finish_reason": "length",
"index": 0,
"message": {
"content": "The Elder Scrolls IV: Oblivion!\n\nReleased",
"role": "assistant"
}
}
],
"created": 1715459876,
"model": "groq/llama3-8b-8192",
"object": "chat.completion",
"system_fingerprint": "fp_179b0f92c9",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 12,
"total_tokens": 22
}
},
{
"id": "chatcmpl-9NnldUfFLmVquFHSX4yAtjCw8PGei",
"choices": [
{
"finish_reason": "length",
"index": 0,
"message": {
"content": "TES4 could refer to The Elder Scrolls IV:",
"role": "assistant"
}
}
],
"created": 1715459877,
"model": "gpt-3.5-turbo-0125",
"object": "chat.completion",
"system_fingerprint": null,
"usage": {
"completion_tokens": 10,
"prompt_tokens": 9,
"total_tokens": 19
}
}
]
```
</TabItem>
</Tabs>
### Pass User LLM API Keys, Fallbacks ### Pass User LLM API Keys, Fallbacks
Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests

View file

@ -653,7 +653,9 @@ from litellm import Router
model_list = [{...}] model_list = [{...}]
router = Router(model_list=model_list, router = Router(model_list=model_list,
allowed_fails=1) # cooldown model if it fails > 1 call in a minute. allowed_fails=1, # cooldown model if it fails > 1 call in a minute.
cooldown_time=100 # cooldown the deployment for 100 seconds if it num_fails > allowed_fails
)
user_message = "Hello, whats the weather in San Francisco??" user_message = "Hello, whats the weather in San Francisco??"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
@ -770,6 +772,8 @@ If the error is a context window exceeded error, fall back to a larger model gro
Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc. Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
You can also set 'default_fallbacks', in case a specific model group is misconfigured / bad.
```python ```python
from litellm import Router from litellm import Router
@ -830,6 +834,7 @@ model_list = [
router = Router(model_list=model_list, router = Router(model_list=model_list,
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
default_fallbacks=["gpt-3.5-turbo-16k"],
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}], context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
set_verbose=True) set_verbose=True)
@ -879,13 +884,11 @@ router = Router(model_list: Optional[list] = None,
cache_responses=True) cache_responses=True)
``` ```
## Pre-Call Checks (Context Window) ## Pre-Call Checks (Context Window, EU-Regions)
Enable pre-call checks to filter out: Enable pre-call checks to filter out:
1. deployments with context window limit < messages for a call. 1. deployments with context window limit < messages for a call.
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[ 2. deployments outside of eu-region
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
])`)
<Tabs> <Tabs>
<TabItem value="sdk" label="SDK"> <TabItem value="sdk" label="SDK">
@ -900,10 +903,14 @@ router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set t
**2. Set Model List** **2. Set Model List**
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`. For context window checks on azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
<Tabs> For 'eu-region' filtering, Set 'region_name' of deployment.
<TabItem value="same-group" label="Same Group">
**Note:** We automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
[**See Code**](https://github.com/BerriAI/litellm/blob/d33e49411d6503cb634f9652873160cd534dec96/litellm/router.py#L2958)
```python ```python
model_list = [ model_list = [
@ -914,10 +921,9 @@ model_list = [
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
}, "region_name": "eu" # 👈 SET 'EU' REGION NAME
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL "base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
} },
}, },
{ {
"model_name": "gpt-3.5-turbo", # model group name "model_name": "gpt-3.5-turbo", # model group name
@ -926,54 +932,26 @@ model_list = [
"api_key": os.getenv("OPENAI_API_KEY"), "api_key": os.getenv("OPENAI_API_KEY"),
}, },
}, },
{
"model_name": "gemini-pro",
"litellm_params: {
"model": "vertex_ai/gemini-pro-1.5",
"vertex_project": "adroit-crow-1234",
"vertex_location": "us-east1" # 👈 AUTOMATICALLY INFERS 'region_name'
}
}
] ]
router = Router(model_list=model_list, enable_pre_call_checks=True) router = Router(model_list=model_list, enable_pre_call_checks=True)
``` ```
</TabItem>
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
```python
model_list = [
{
"model_name": "gpt-3.5-turbo-small", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
}
},
{
"model_name": "gpt-3.5-turbo-large", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "claude-opus",
"litellm_params": { call
"model": "claude-3-opus-20240229",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
```
</TabItem>
</Tabs>
**3. Test it!** **3. Test it!**
<Tabs>
<TabItem value="context-window-check" label="Context Window Check">
```python ```python
""" """
- Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k) - Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
@ -983,7 +961,6 @@ router = Router(model_list=model_list, enable_pre_call_checks=True, context_wind
from litellm import Router from litellm import Router
import os import os
try:
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # model group name "model_name": "gpt-3.5-turbo", # model group name
@ -992,6 +969,7 @@ model_list = [
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
"base_model": "azure/gpt-35-turbo",
}, },
"model_info": { "model_info": {
"base_model": "azure/gpt-35-turbo", "base_model": "azure/gpt-35-turbo",
@ -1021,6 +999,59 @@ response = router.completion(
print(f"response: {response}") print(f"response: {response}")
``` ```
</TabItem> </TabItem>
<TabItem value="eu-region-check" label="EU Region Check">
```python
"""
- Give 2 gpt-3.5-turbo deployments, in eu + non-eu regions
- Make a call
- Assert it picks the eu-region model
"""
from litellm import Router
import os
model_list = [
{
"model_name": "gpt-3.5-turbo", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"region_name": "eu"
},
"model_info": {
"id": "1"
}
},
{
"model_name": "gpt-3.5-turbo", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"model_info": {
"id": "2"
}
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True)
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Who was Alexander?"}],
)
print(f"response: {response}")
print(f"response id: {response._hidden_params['model_id']}")
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="proxy" label="Proxy"> <TabItem value="proxy" label="Proxy">
:::info :::info
@ -1283,10 +1314,11 @@ def __init__(
num_retries: int = 0, num_retries: int = 0,
timeout: Optional[float] = None, timeout: Optional[float] = None,
default_litellm_params={}, # default params for Router.chat.completion.create default_litellm_params={}, # default params for Router.chat.completion.create
fallbacks: List = [], fallbacks: Optional[List] = None,
default_fallbacks: Optional[List] = None
allowed_fails: Optional[int] = None, # Number of times a deployment can failbefore being added to cooldown allowed_fails: Optional[int] = None, # Number of times a deployment can failbefore being added to cooldown
cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure
context_window_fallbacks: List = [], context_window_fallbacks: Optional[List] = None,
model_group_alias: Optional[dict] = {}, model_group_alias: Optional[dict] = {},
retry_after: int = 0, # (min) time to wait before retrying a failed request retry_after: int = 0, # (min) time to wait before retrying a failed request
routing_strategy: Literal[ routing_strategy: Literal[

View file

@ -39,6 +39,7 @@ const sidebars = {
"proxy/demo", "proxy/demo",
"proxy/configs", "proxy/configs",
"proxy/reliability", "proxy/reliability",
"proxy/cost_tracking",
"proxy/users", "proxy/users",
"proxy/user_keys", "proxy/user_keys",
"proxy/enterprise", "proxy/enterprise",
@ -52,7 +53,6 @@ const sidebars = {
"proxy/team_based_routing", "proxy/team_based_routing",
"proxy/customer_routing", "proxy/customer_routing",
"proxy/ui", "proxy/ui",
"proxy/cost_tracking",
"proxy/token_auth", "proxy/token_auth",
{ {
type: "category", type: "category",

View file

@ -10,7 +10,6 @@ from litellm.caching import DualCache
from typing import Literal, Union from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -19,8 +18,6 @@ import traceback
import dotenv, os import dotenv, os
import requests import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -1,6 +1,7 @@
# Enterprise Proxy Util Endpoints # Enterprise Proxy Util Endpoints
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
import collections import collections
from datetime import datetime
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None): async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
@ -18,26 +19,33 @@ async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
return response return response
async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=None): async def ui_get_spend_by_tags(start_date: str, end_date: str, prisma_client):
response = await prisma_client.db.query_raw(
""" sql_query = """
SELECT SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag, jsonb_array_elements_text(request_tags) AS individual_request_tag,
DATE(s."startTime") AS spend_date, DATE(s."startTime") AS spend_date,
COUNT(*) AS log_count, COUNT(*) AS log_count,
SUM(spend) AS total_spend SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs" s FROM "LiteLLM_SpendLogs" s
WHERE s."startTime" >= current_date - interval '30 days' WHERE
DATE(s."startTime") >= $1::date
AND DATE(s."startTime") <= $2::date
GROUP BY individual_request_tag, spend_date GROUP BY individual_request_tag, spend_date
ORDER BY spend_date; ORDER BY spend_date
""" LIMIT 100;
"""
response = await prisma_client.db.query_raw(
sql_query,
start_date,
end_date,
) )
# print("tags - spend") # print("tags - spend")
# print(response) # print(response)
# Bar Chart 1 - Spend per tag - Top 10 tags by spend # Bar Chart 1 - Spend per tag - Top 10 tags by spend
total_spend_per_tag = collections.defaultdict(float) total_spend_per_tag: collections.defaultdict = collections.defaultdict(float)
total_requests_per_tag = collections.defaultdict(int) total_requests_per_tag: collections.defaultdict = collections.defaultdict(int)
for row in response: for row in response:
tag_name = row["individual_request_tag"] tag_name = row["individual_request_tag"]
tag_spend = row["total_spend"] tag_spend = row["total_spend"]
@ -49,15 +57,18 @@ async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=Non
# convert to ui format # convert to ui format
ui_tags = [] ui_tags = []
for tag in sorted_tags: for tag in sorted_tags:
current_spend = tag[1]
if current_spend is not None and isinstance(current_spend, float):
current_spend = round(current_spend, 4)
ui_tags.append( ui_tags.append(
{ {
"name": tag[0], "name": tag[0],
"value": tag[1], "spend": current_spend,
"log_count": total_requests_per_tag[tag[0]], "log_count": total_requests_per_tag[tag[0]],
} }
) )
return {"top_10_tags": ui_tags} return {"spend_per_tag": ui_tags}
async def view_spend_logs_from_clickhouse( async def view_spend_logs_from_clickhouse(

View file

@ -71,6 +71,7 @@ azure_key: Optional[str] = None
anthropic_key: Optional[str] = None anthropic_key: Optional[str] = None
replicate_key: Optional[str] = None replicate_key: Optional[str] = None
cohere_key: Optional[str] = None cohere_key: Optional[str] = None
clarifai_key: Optional[str] = None
maritalk_key: Optional[str] = None maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None ai21_key: Optional[str] = None
ollama_key: Optional[str] = None ollama_key: Optional[str] = None
@ -101,6 +102,9 @@ blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
################## ##################
### PREVIEW FEATURES ###
enable_preview_features: bool = False
##################
logging: bool = True logging: bool = True
caching: bool = ( caching: bool = (
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
@ -401,6 +405,73 @@ replicate_models: List = [
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad", "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
] ]
clarifai_models: List = [
"clarifai/meta.Llama-3.Llama-3-8B-Instruct",
"clarifai/gcp.generate.gemma-1_1-7b-it",
"clarifai/mistralai.completion.mixtral-8x22B",
"clarifai/cohere.generate.command-r-plus",
"clarifai/databricks.drbx.dbrx-instruct",
"clarifai/mistralai.completion.mistral-large",
"clarifai/mistralai.completion.mistral-medium",
"clarifai/mistralai.completion.mistral-small",
"clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
"clarifai/gcp.generate.gemma-2b-it",
"clarifai/gcp.generate.gemma-7b-it",
"clarifai/deci.decilm.deciLM-7B-instruct",
"clarifai/mistralai.completion.mistral-7B-Instruct",
"clarifai/gcp.generate.gemini-pro",
"clarifai/anthropic.completion.claude-v1",
"clarifai/anthropic.completion.claude-instant-1_2",
"clarifai/anthropic.completion.claude-instant",
"clarifai/anthropic.completion.claude-v2",
"clarifai/anthropic.completion.claude-2_1",
"clarifai/meta.Llama-2.codeLlama-70b-Python",
"clarifai/meta.Llama-2.codeLlama-70b-Instruct",
"clarifai/openai.completion.gpt-3_5-turbo-instruct",
"clarifai/meta.Llama-2.llama2-7b-chat",
"clarifai/meta.Llama-2.llama2-13b-chat",
"clarifai/meta.Llama-2.llama2-70b-chat",
"clarifai/openai.chat-completion.gpt-4-turbo",
"clarifai/microsoft.text-generation.phi-2",
"clarifai/meta.Llama-2.llama2-7b-chat-vllm",
"clarifai/upstage.solar.solar-10_7b-instruct",
"clarifai/openchat.openchat.openchat-3_5-1210",
"clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
"clarifai/gcp.generate.text-bison",
"clarifai/meta.Llama-2.llamaGuard-7b",
"clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
"clarifai/openai.chat-completion.GPT-4",
"clarifai/openai.chat-completion.GPT-3_5-turbo",
"clarifai/ai21.complete.Jurassic2-Grande",
"clarifai/ai21.complete.Jurassic2-Grande-Instruct",
"clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
"clarifai/ai21.complete.Jurassic2-Jumbo",
"clarifai/ai21.complete.Jurassic2-Large",
"clarifai/cohere.generate.cohere-generate-command",
"clarifai/wizardlm.generate.wizardCoder-Python-34B",
"clarifai/wizardlm.generate.wizardLM-70B",
"clarifai/tiiuae.falcon.falcon-40b-instruct",
"clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
"clarifai/gcp.generate.code-gecko",
"clarifai/gcp.generate.code-bison",
"clarifai/mistralai.completion.mistral-7B-OpenOrca",
"clarifai/mistralai.completion.openHermes-2-mistral-7B",
"clarifai/wizardlm.generate.wizardLM-13B",
"clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
"clarifai/wizardlm.generate.wizardCoder-15B",
"clarifai/microsoft.text-generation.phi-1_5",
"clarifai/databricks.Dolly-v2.dolly-v2-12b",
"clarifai/bigcode.code.StarCoder",
"clarifai/salesforce.xgen.xgen-7b-8k-instruct",
"clarifai/mosaicml.mpt.mpt-7b-instruct",
"clarifai/anthropic.completion.claude-3-opus",
"clarifai/anthropic.completion.claude-3-sonnet",
"clarifai/gcp.generate.gemini-1_5-pro",
"clarifai/gcp.generate.imagen-2",
"clarifai/salesforce.blip.general-english-image-caption-blip-2",
]
huggingface_models: List = [ huggingface_models: List = [
"meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-7b-hf",
"meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-7b-chat-hf",
@ -506,6 +577,7 @@ provider_list: List = [
"text-completion-openai", "text-completion-openai",
"cohere", "cohere",
"cohere_chat", "cohere_chat",
"clarifai",
"anthropic", "anthropic",
"replicate", "replicate",
"huggingface", "huggingface",
@ -656,6 +728,7 @@ from .llms.predibase import PredibaseConfig
from .llms.anthropic_text import AnthropicTextConfig from .llms.anthropic_text import AnthropicTextConfig
from .llms.replicate import ReplicateConfig from .llms.replicate import ReplicateConfig
from .llms.cohere import CohereConfig from .llms.cohere import CohereConfig
from .llms.clarifai import ClarifaiConfig
from .llms.ai21 import AI21Config from .llms.ai21 import AI21Config
from .llms.together_ai import TogetherAIConfig from .llms.together_ai import TogetherAIConfig
from .llms.cloudflare import CloudflareConfig from .llms.cloudflare import CloudflareConfig
@ -670,6 +743,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig
from .llms.bedrock import ( from .llms.bedrock import (
AmazonTitanConfig, AmazonTitanConfig,
AmazonAI21Config, AmazonAI21Config,
@ -681,7 +755,7 @@ from .llms.bedrock import (
AmazonMistralConfig, AmazonMistralConfig,
AmazonBedrockGlobalConfig, AmazonBedrockGlobalConfig,
) )
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, MistralConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
from .llms.watsonx import IBMWatsonXAIConfig from .llms.watsonx import IBMWatsonXAIConfig
from .main import * # type: ignore from .main import * # type: ignore

View file

@ -373,11 +373,12 @@ class RedisCache(BaseCache):
print_verbose( print_verbose(
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
) )
json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided. # Set the value with a TTL if it's provided.
if ttl is not None: if ttl is not None:
pipe.setex(cache_key, ttl, json.dumps(cache_value)) pipe.setex(cache_key, ttl, json_cache_value)
else: else:
pipe.set(cache_key, json.dumps(cache_value)) pipe.set(cache_key, json_cache_value)
# Execute the pipeline and return the results. # Execute the pipeline and return the results.
results = await pipe.execute() results = await pipe.execute()
@ -810,9 +811,7 @@ class RedisSemanticCache(BaseCache):
# get the prompt # get the prompt
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
# create an embedding for prompt # create an embedding for prompt
embedding_response = litellm.embedding( embedding_response = litellm.embedding(
@ -847,9 +846,7 @@ class RedisSemanticCache(BaseCache):
# get the messages # get the messages
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
# convert to embedding # convert to embedding
embedding_response = litellm.embedding( embedding_response = litellm.embedding(
@ -909,9 +906,7 @@ class RedisSemanticCache(BaseCache):
# get the prompt # get the prompt
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
# create an embedding for prompt # create an embedding for prompt
router_model_names = ( router_model_names = (
[m["model_name"] for m in llm_model_list] [m["model_name"] for m in llm_model_list]
@ -964,9 +959,7 @@ class RedisSemanticCache(BaseCache):
# get the messages # get the messages
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
router_model_names = ( router_model_names = (
[m["model_name"] for m in llm_model_list] [m["model_name"] for m in llm_model_list]

View file

@ -9,25 +9,12 @@
## LiteLLM versions of the OpenAI Exception Types ## LiteLLM versions of the OpenAI Exception Types
from openai import ( import openai
AuthenticationError,
BadRequestError,
NotFoundError,
RateLimitError,
APIStatusError,
OpenAIError,
APIError,
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError,
PermissionDeniedError,
)
import httpx import httpx
from typing import Optional from typing import Optional
class AuthenticationError(AuthenticationError): # type: ignore class AuthenticationError(openai.AuthenticationError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 401 self.status_code = 401
self.message = message self.message = message
@ -39,7 +26,7 @@ class AuthenticationError(AuthenticationError): # type: ignore
# raise when invalid models passed, example gpt-8 # raise when invalid models passed, example gpt-8
class NotFoundError(NotFoundError): # type: ignore class NotFoundError(openai.NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 404 self.status_code = 404
self.message = message self.message = message
@ -50,7 +37,7 @@ class NotFoundError(NotFoundError): # type: ignore
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class BadRequestError(BadRequestError): # type: ignore class BadRequestError(openai.BadRequestError): # type: ignore
def __init__( def __init__(
self, message, model, llm_provider, response: Optional[httpx.Response] = None self, message, model, llm_provider, response: Optional[httpx.Response] = None
): ):
@ -69,7 +56,7 @@ class BadRequestError(BadRequestError): # type: ignore
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422 self.status_code = 422
self.message = message self.message = message
@ -80,7 +67,7 @@ class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore class Timeout(openai.APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
request = httpx.Request(method="POST", url="https://api.openai.com/v1") request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__( super().__init__(
@ -96,7 +83,7 @@ class Timeout(APITimeoutError): # type: ignore
return str(self.message) return str(self.message)
class PermissionDeniedError(PermissionDeniedError): # type:ignore class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 403 self.status_code = 403
self.message = message self.message = message
@ -107,7 +94,7 @@ class PermissionDeniedError(PermissionDeniedError): # type:ignore
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class RateLimitError(RateLimitError): # type: ignore class RateLimitError(openai.RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 429 self.status_code = 429
self.message = message self.message = message
@ -148,7 +135,7 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(APIStatusError): # type: ignore class ServiceUnavailableError(openai.APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 503 self.status_code = 503
self.message = message self.message = message
@ -160,7 +147,7 @@ class ServiceUnavailableError(APIStatusError): # type: ignore
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401 # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(APIError): # type: ignore class APIError(openai.APIError): # type: ignore
def __init__( def __init__(
self, status_code, message, llm_provider, model, request: httpx.Request self, status_code, message, llm_provider, model, request: httpx.Request
): ):
@ -172,7 +159,7 @@ class APIError(APIError): # type: ignore
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(APIConnectionError): # type: ignore class APIConnectionError(openai.APIConnectionError): # type: ignore
def __init__(self, message, llm_provider, model, request: httpx.Request): def __init__(self, message, llm_provider, model, request: httpx.Request):
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
@ -182,7 +169,7 @@ class APIConnectionError(APIConnectionError): # type: ignore
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(APIResponseValidationError): # type: ignore class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
def __init__(self, message, llm_provider, model): def __init__(self, message, llm_provider, model):
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
@ -192,7 +179,7 @@ class APIResponseValidationError(APIResponseValidationError): # type: ignore
super().__init__(response=response, body=None, message=message) super().__init__(response=response, body=None, message=message)
class OpenAIError(OpenAIError): # type: ignore class OpenAIError(openai.OpenAIError): # type: ignore
def __init__(self, original_exception): def __init__(self, original_exception):
self.status_code = original_exception.http_status self.status_code = original_exception.http_status
super().__init__( super().__init__(
@ -214,7 +201,7 @@ class BudgetExceededError(Exception):
## DEPRECATED ## ## DEPRECATED ##
class InvalidRequestError(BadRequestError): # type: ignore class InvalidRequestError(openai.BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 400 self.status_code = 400
self.message = message self.message = message

View file

@ -1,8 +1,6 @@
#### What this does #### #### What this does ####
# On success + failure, log events to aispend.io # On success + failure, log events to aispend.io
import dotenv, os import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime import datetime

View file

@ -3,7 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime import datetime

View file

@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal, Union from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -18,8 +16,6 @@ import traceback
import dotenv, os import dotenv, os
import requests import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal, Union, Optional from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
import litellm import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -1,8 +1,6 @@
#### What this does #### #### What this does ####
# On success, logs events to Langfuse # On success, logs events to Langfuse
import dotenv, os import os
dotenv.load_dotenv() # Loading env variables using dotenv
import copy import copy
import traceback import traceback
from packaging.version import Version from packaging.version import Version
@ -323,6 +321,9 @@ class LangFuseLogger:
trace_id = clean_metadata.pop("trace_id", None) trace_id = clean_metadata.pop("trace_id", None)
existing_trace_id = clean_metadata.pop("existing_trace_id", None) existing_trace_id = clean_metadata.pop("existing_trace_id", None)
update_trace_keys = clean_metadata.pop("update_trace_keys", []) update_trace_keys = clean_metadata.pop("update_trace_keys", [])
debug = clean_metadata.pop("debug_langfuse", None)
mask_input = clean_metadata.pop("mask_input", False)
mask_output = clean_metadata.pop("mask_output", False)
if trace_name is None and existing_trace_id is None: if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name # just log `litellm-{call_type}` as the trace name
@ -350,15 +351,15 @@ class LangFuseLogger:
# Special keys that are found in the function arguments and not the metadata # Special keys that are found in the function arguments and not the metadata
if "input" in update_trace_keys: if "input" in update_trace_keys:
trace_params["input"] = input trace_params["input"] = input if not mask_input else "redacted-by-litellm"
if "output" in update_trace_keys: if "output" in update_trace_keys:
trace_params["output"] = output trace_params["output"] = output if not mask_output else "redacted-by-litellm"
else: # don't overwrite an existing trace else: # don't overwrite an existing trace
trace_params = { trace_params = {
"id": trace_id, "id": trace_id,
"name": trace_name, "name": trace_name,
"session_id": session_id, "session_id": session_id,
"input": input, "input": input if not mask_input else "redacted-by-litellm",
"version": clean_metadata.pop( "version": clean_metadata.pop(
"trace_version", clean_metadata.get("version", None) "trace_version", clean_metadata.get("version", None)
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence ), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
@ -374,7 +375,14 @@ class LangFuseLogger:
if level == "ERROR": if level == "ERROR":
trace_params["status_message"] = output trace_params["status_message"] = output
else: else:
trace_params["output"] = output trace_params["output"] = output if not mask_output else "redacted-by-litellm"
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
if "metadata" in trace_params:
# log the raw_metadata in the trace
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
else:
trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
cost = kwargs.get("response_cost", None) cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}") print_verbose(f"trace: {cost}")
@ -426,7 +434,6 @@ class LangFuseLogger:
"url": url, "url": url,
"headers": clean_headers, "headers": clean_headers,
} }
trace = self.Langfuse.trace(**trace_params) trace = self.Langfuse.trace(**trace_params)
generation_id = None generation_id = None
@ -459,8 +466,8 @@ class LangFuseLogger:
"end_time": end_time, "end_time": end_time,
"model": kwargs["model"], "model": kwargs["model"],
"model_parameters": optional_params, "model_parameters": optional_params,
"input": input, "input": input if not mask_input else "redacted-by-litellm",
"output": output, "output": output if not mask_output else "redacted-by-litellm",
"usage": usage, "usage": usage,
"metadata": clean_metadata, "metadata": clean_metadata,
"level": level, "level": level,
@ -468,7 +475,29 @@ class LangFuseLogger:
} }
if supports_prompt: if supports_prompt:
generation_params["prompt"] = clean_metadata.pop("prompt", None) user_prompt = clean_metadata.pop("prompt", None)
if user_prompt is None:
pass
elif isinstance(user_prompt, dict):
from langfuse.model import (
TextPromptClient,
ChatPromptClient,
Prompt_Text,
Prompt_Chat,
)
if user_prompt.get("type", "") == "chat":
_prompt_chat = Prompt_Chat(**user_prompt)
generation_params["prompt"] = ChatPromptClient(
prompt=_prompt_chat
)
elif user_prompt.get("type", "") == "text":
_prompt_text = Prompt_Text(**user_prompt)
generation_params["prompt"] = TextPromptClient(
prompt=_prompt_text
)
else:
generation_params["prompt"] = user_prompt
if output is not None and isinstance(output, str) and level == "ERROR": if output is not None and isinstance(output, str) and level == "ERROR":
generation_params["status_message"] = output generation_params["status_message"] = output

View file

@ -3,8 +3,6 @@
import dotenv, os # type: ignore import dotenv, os # type: ignore
import requests # type: ignore import requests # type: ignore
from datetime import datetime from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import asyncio import asyncio
import types import types

View file

@ -2,13 +2,10 @@
# On success + failure, log events to lunary.ai # On success + failure, log events to lunary.ai
from datetime import datetime, timezone from datetime import datetime, timezone
import traceback import traceback
import dotenv
import importlib import importlib
import packaging import packaging
dotenv.load_dotenv()
# convert to {completion: xx, tokens: xx} # convert to {completion: xx, tokens: xx}
def parse_usage(usage): def parse_usage(usage):
@ -79,14 +76,16 @@ class LunaryLogger:
version = importlib.metadata.version("lunary") version = importlib.metadata.version("lunary")
# if version < 0.1.43 then raise ImportError # if version < 0.1.43 then raise ImportError
if packaging.version.Version(version) < packaging.version.Version("0.1.43"): if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
print( print( # noqa
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'" "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
) )
raise ImportError raise ImportError
self.lunary_client = lunary self.lunary_client = lunary
except ImportError: except ImportError:
print("Lunary not installed. Please install it using 'pip install lunary'") print( # noqa
"Lunary not installed. Please install it using 'pip install lunary'"
) # noqa
raise ImportError raise ImportError
def log_event( def log_event(

View file

@ -3,8 +3,6 @@
import dotenv, os, json import dotenv, os, json
import litellm import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

View file

@ -4,8 +4,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -5,8 +5,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -1,9 +1,7 @@
#### What this does #### #### What this does ####
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os import os
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -2,8 +2,6 @@
# Class for sending Slack Alerts # # Class for sending Slack Alerts #
import dotenv, os import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
dotenv.load_dotenv() # Loading env variables using dotenv
from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm, threading import litellm, threading
from typing import List, Literal, Any, Union, Optional, Dict from typing import List, Literal, Any, Union, Optional, Dict
@ -33,7 +31,8 @@ class LiteLLMBase(BaseModel):
class SlackAlertingArgs(LiteLLMBase): class SlackAlertingArgs(LiteLLMBase):
daily_report_frequency: int = 12 * 60 * 60 # 12 hours default_daily_report_frequency: int = 12 * 60 * 60 # 12 hours
daily_report_frequency: int = int(os.getenv("SLACK_DAILY_REPORT_FREQUENCY", default_daily_report_frequency))
report_check_interval: int = 5 * 60 # 5 minutes report_check_interval: int = 5 * 60 # 5 minutes
@ -78,16 +77,14 @@ class SlackAlerting(CustomLogger):
internal_usage_cache: Optional[DualCache] = None, internal_usage_cache: Optional[DualCache] = None,
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds) alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [], alerting: Optional[List] = [],
alert_types: Optional[ alert_types: List[
List[ Literal[
Literal[ "llm_exceptions",
"llm_exceptions", "llm_too_slow",
"llm_too_slow", "llm_requests_hanging",
"llm_requests_hanging", "budget_alerts",
"budget_alerts", "db_exceptions",
"db_exceptions", "daily_reports",
"daily_reports",
]
] ]
] = [ ] = [
"llm_exceptions", "llm_exceptions",
@ -242,6 +239,8 @@ class SlackAlerting(CustomLogger):
end_time=end_time, end_time=end_time,
) )
) )
if litellm.turn_off_message_logging:
messages = "Message not logged. `litellm.turn_off_message_logging=True`."
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold: if time_difference_float > self.alerting_threshold:
@ -464,6 +463,11 @@ class SlackAlerting(CustomLogger):
messages = messages[:100] messages = messages[:100]
except: except:
messages = "" messages = ""
if litellm.turn_off_message_logging:
messages = (
"Message not logged. `litellm.turn_off_message_logging=True`."
)
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
else: else:
request_info = "" request_info = ""
@ -814,14 +818,6 @@ Model Info:
updated_at=litellm.utils.get_utc_datetime(), updated_at=litellm.utils.get_utc_datetime(),
) )
) )
if "llm_exceptions" in self.alert_types:
original_exception = kwargs.get("exception", None)
await self.send_alert(
message="LLM API Failure - " + str(original_exception),
level="High",
alert_type="llm_exceptions",
)
async def _run_scheduler_helper(self, llm_router) -> bool: async def _run_scheduler_helper(self, llm_router) -> bool:
""" """
@ -885,3 +881,99 @@ Model Info:
) # shuffle to prevent collisions ) # shuffle to prevent collisions
await asyncio.sleep(interval) await asyncio.sleep(interval)
return return
async def send_weekly_spend_report(self):
""" """
try:
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
todays_date = datetime.datetime.now().date()
week_before = todays_date - datetime.timedelta(days=7)
weekly_spend_per_team, weekly_spend_per_tag = (
await _get_spend_report_for_time_range(
start_date=week_before.strftime("%Y-%m-%d"),
end_date=todays_date.strftime("%Y-%m-%d"),
)
)
_weekly_spend_message = f"*💸 Weekly Spend Report for `{week_before.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` *\n"
if weekly_spend_per_team is not None:
_weekly_spend_message += "\n*Team Spend Report:*\n"
for spend in weekly_spend_per_team:
_team_spend = spend["total_spend"]
_team_spend = float(_team_spend)
# round to 4 decimal places
_team_spend = round(_team_spend, 4)
_weekly_spend_message += (
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
)
if weekly_spend_per_tag is not None:
_weekly_spend_message += "\n*Tag Spend Report:*\n"
for spend in weekly_spend_per_tag:
_tag_spend = spend["total_spend"]
_tag_spend = float(_tag_spend)
# round to 4 decimal places
_tag_spend = round(_tag_spend, 4)
_weekly_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
await self.send_alert(
message=_weekly_spend_message,
level="Low",
alert_type="daily_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)
async def send_monthly_spend_report(self):
""" """
try:
from calendar import monthrange
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
todays_date = datetime.datetime.now().date()
first_day_of_month = todays_date.replace(day=1)
_, last_day_of_month = monthrange(todays_date.year, todays_date.month)
last_day_of_month = first_day_of_month + datetime.timedelta(
days=last_day_of_month - 1
)
monthly_spend_per_team, monthly_spend_per_tag = (
await _get_spend_report_for_time_range(
start_date=first_day_of_month.strftime("%Y-%m-%d"),
end_date=last_day_of_month.strftime("%Y-%m-%d"),
)
)
_spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
if monthly_spend_per_team is not None:
_spend_message += "\n*Team Spend Report:*\n"
for spend in monthly_spend_per_team:
_team_spend = spend["total_spend"]
_team_spend = float(_team_spend)
# round to 4 decimal places
_team_spend = round(_team_spend, 4)
_spend_message += (
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
)
if monthly_spend_per_tag is not None:
_spend_message += "\n*Tag Spend Report:*\n"
for spend in monthly_spend_per_tag:
_tag_spend = spend["total_spend"]
_tag_spend = float(_tag_spend)
# round to 4 decimal places
_tag_spend = round(_tag_spend, 4)
_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
await self.send_alert(
message=_spend_message,
level="Low",
alert_type="daily_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm import litellm

View file

@ -21,11 +21,11 @@ try:
# contains a (known) object attribute # contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"] object: Literal["chat.completion", "edit", "text_completion"]
def __getitem__(self, key: K) -> V: def __getitem__(self, key: K) -> V: ... # noqa
... # pragma: no cover
def get(self, key: K, default: Optional[V] = None) -> Optional[V]: def get( # noqa
... # pragma: no cover self, key: K, default: Optional[V] = None
) -> Optional[V]: ... # pragma: no cover
class OpenAIRequestResponseResolver: class OpenAIRequestResponseResolver:
def __call__( def __call__(
@ -173,12 +173,11 @@ except:
#### What this does #### #### What this does ####
# On success, logs events to Langfuse # On success, logs events to Langfuse
import dotenv, os import os
import requests import requests
import requests import requests
from datetime import datetime from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -3,7 +3,7 @@ import json
from enum import Enum from enum import Enum
import requests, copy # type: ignore import requests, copy # type: ignore
import time import time
from typing import Callable, Optional, List from typing import Callable, Optional, List, Union
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def process_streaming_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> CustomStreamWrapper:
"""
Return stream object for tool-calling + streaming
"""
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": content["id"],
"type": "function",
"function": {
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
raise AnthropicError(
status_code=422,
message="Unprocessable response object - {}".format(response.text),
)
def process_response( def process_response(
self, self,
model, model: str,
response, response: Union[requests.Response, httpx.Response],
model_response, model_response: ModelResponse,
_is_function_call, stream: bool,
stream, logging_obj: litellm.utils.Logging,
logging_obj, optional_params: dict,
api_key, api_key: str,
data, data: Union[dict, str],
messages, messages: List,
print_verbose, print_verbose,
): encoding,
) -> ModelResponse:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=messages, input=messages,
@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM):
completion_response["stop_reason"] completion_response["stop_reason"]
) )
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"] prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"] completion_tokens = completion_response["usage"]["output_tokens"]
@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage) # type: ignore
return model_response return model_response
async def acompletion_stream_function( async def acompletion_stream_function(
@ -289,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj, logging_obj,
stream, stream,
_is_function_call, _is_function_call,
data=None, data: dict,
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj, logging_obj,
stream, stream,
_is_function_call, _is_function_call,
data=None, data: dict,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
headers={}, headers={},
): ) -> Union[ModelResponse, CustomStreamWrapper]:
self.async_handler = AsyncHTTPHandler( self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0) timeout=httpx.Timeout(timeout=600.0, connect=5.0)
) )
response = await self.async_handler.post( response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data) api_base, headers=headers, data=json.dumps(data)
) )
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response( return self.process_response(
model=model, model=model,
response=response, response=response,
model_response=model_response, model_response=model_response,
_is_function_call=_is_function_call,
stream=stream, stream=stream,
logging_obj=logging_obj, logging_obj=logging_obj,
api_key=api_key, api_key=api_key,
data=data, data=data,
messages=messages, messages=messages,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
) )
def completion( def completion(
@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
acompletion=None, acompletion=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
raise AnthropicError( raise AnthropicError(
status_code=response.status_code, message=response.text status_code=response.status_code, message=response.text
) )
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response( return self.process_response(
model=model, model=model,
response=response, response=response,
model_response=model_response, model_response=model_response,
_is_function_call=_is_function_call,
stream=stream, stream=stream,
logging_obj=logging_obj, logging_obj=logging_obj,
api_key=api_key, api_key=api_key,
data=data, data=data,
messages=messages, messages=messages,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
) )
def embedding(self): def embedding(self):

View file

@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def process_response( def _process_response(
self, model_response: ModelResponse, response, encoding, prompt: str, model: str self, model_response: ModelResponse, response, encoding, prompt: str, model: str
): ):
## RESPONSE OBJECT ## RESPONSE OBJECT
@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
response = self.process_response( response = self._process_response(
model_response=model_response, model_response=model_response,
response=response, response=response,
encoding=encoding, encoding=encoding,
@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
response = self.process_response( response = self._process_response(
model_response=model_response, model_response=model_response,
response=response, response=response,
encoding=encoding, encoding=encoding,

View file

@ -10,7 +10,7 @@ from litellm.utils import (
TranscriptionResponse, TranscriptionResponse,
get_secret, get_secret,
) )
from typing import Callable, Optional, BinaryIO from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
import httpx # type: ignore import httpx # type: ignore
@ -107,6 +107,12 @@ class AzureOpenAIConfig(OpenAIConfig):
optional_params["azure_ad_token"] = value optional_params["azure_ad_token"] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
"""
return ["europe", "sweden", "switzerland", "france", "uk"]
def select_azure_base_url_or_endpoint(azure_client_params: dict): def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = { # azure_client_params = {

View file

@ -1,12 +1,32 @@
## This is a template base class to be used for adding new LLM providers via API calls ## This is a template base class to be used for adding new LLM providers via API calls
import litellm import litellm
import httpx import httpx, requests
from typing import Optional from typing import Optional, Union
from litellm.utils import Logging
class BaseLLM: class BaseLLM:
_client_session: Optional[httpx.Client] = None _client_session: Optional[httpx.Client] = None
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> litellm.utils.ModelResponse:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self): def create_client_session(self):
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session

View file

@ -52,6 +52,16 @@ class AmazonBedrockGlobalConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.aws-services.info/bedrock.html
"""
return [
"eu-west-1",
"eu-west-3",
"eu-central-1",
]
class AmazonTitanConfig: class AmazonTitanConfig:
""" """

View file

@ -0,0 +1,733 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V0 - just covers cohere command-r support
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import (
Callable,
Optional,
List,
Literal,
Union,
Any,
TypedDict,
Tuple,
Iterator,
AsyncIterator,
)
from litellm.utils import (
ModelResponse,
Usage,
map_finish_reason,
CustomStreamWrapper,
Message,
Choices,
get_secret,
Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt
from litellm.types.llms.bedrock import *
class AmazonCohereChatConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
"""
documents: Optional[List[Document]] = None
search_queries_only: Optional[bool] = None
preamble: Optional[str] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
p: Optional[float] = None
k: Optional[float] = None
prompt_truncation: Optional[str] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
seed: Optional[int] = None
return_prompt: Optional[bool] = None
stop_sequences: Optional[List[str]] = None
raw_prompting: Optional[bool] = None
def __init__(
self,
documents: Optional[List[Document]] = None,
search_queries_only: Optional[bool] = None,
preamble: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
p: Optional[float] = None,
k: Optional[float] = None,
prompt_truncation: Optional[str] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
return_prompt: Optional[bool] = None,
stop_sequences: Optional[str] = None,
raw_prompting: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self) -> List[str]:
return [
"max_tokens",
"stream",
"stop",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
"seed",
"stop",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if "seed":
optional_params["seed"] = value
return optional_params
class BedrockLLM(BaseLLM):
"""
Example call
```
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
--header 'Content-Type: application/json' \
--header 'Accept: application/json' \
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
--data-raw '{
"prompt": "Hi",
"temperature": 0,
"p": 0.9,
"max_tokens": 4096
}'
```
"""
def __init__(self) -> None:
super().__init__()
def convert_messages_to_prompt(
self, model, messages, provider, custom_prompt_dict
) -> Tuple[str, Optional[list]]:
# handle anthropic prompts and amazon titan prompts
prompt = ""
chat_history: Optional[list] = None
if provider == "anthropic" or provider == "amazon":
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "mistral":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "cohere":
prompt, chat_history = cohere_message_pt(messages=messages)
else:
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
return prompt, chat_history # type: ignore
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
) = params_to_check
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
return sts_response["Credentials"]
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise BedrockError(message=response.text, status_code=422)
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
raise BedrockError(message=response.text, status_code=422)
## CALCULATING USAGE - bedrock returns usage in the headers
prompt_tokens = int(
response.headers.get(
"x-amzn-bedrock-input-token-count",
len(encoding.encode("".join(m.get("content", "") for m in messages))),
)
)
completion_tokens = int(
response.headers.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response.choices[0].message.content, # type: ignore
disallowed_special=(),
)
),
)
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if stream is not None and stream == True:
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
else:
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
provider = model.split(".")[0]
prompt, chat_history = self.convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
if provider == "cohere":
if model.startswith("cohere.command-r"):
## LOAD CONFIG
config = litellm.AmazonCohereChatConfig().get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
_data = {"message": prompt, **inference_params}
if chat_history is not None:
_data["chat_history"] = chat_history
data = json.dumps(_data)
else:
## LOAD CONFIG
config = litellm.AmazonCohereConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream == True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
data = json.dumps({"prompt": prompt, **inference_params})
else:
raise Exception("UNSUPPORTED PROVIDER")
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = HTTPHandler(**_params) # type: ignore
else:
self.client = client
if stream is not None and stream == True:
response = self.client.post(
url=prepped.url,
headers=prepped.headers, # type: ignore
data=data,
stream=stream,
)
if response.status_code != 200:
raise BedrockError(
status_code=response.status_code, message=response.text
)
decoder = AWSEventStreamDecoder()
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
try:
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.text)
decoder = AWSEventStreamDecoder()
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
def embedding(self, *args, **kwargs):
return super().embedding(*args, **kwargs)
def get_response_stream_shape():
from botocore.model import ServiceModel
from botocore.loaders import Loader
loader = Loader()
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
bedrock_service_model = ServiceModel(bedrock_service_dict)
return bedrock_service_model.shape_for("ResponseStream")
class AWSEventStreamDecoder:
def __init__(self) -> None:
from botocore.parsers import EventStreamJSONParser
self.parser = EventStreamJSONParser()
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
# sse_event = ServerSentEvent(data=message, event="completion")
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[GenericStreamingChunk]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]

328
litellm/llms/clarifai.py Normal file
View file

@ -0,0 +1,328 @@
import os, types, traceback
import json
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, Choices, Message, CustomStreamWrapper
import litellm
import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .prompt_templates.factory import prompt_factory, custom_prompt
class ClarifaiError(Exception):
def __init__(self, status_code, message, url):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=url
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
)
class ClarifaiConfig:
"""
Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
TODO fill in the details
"""
max_tokens: Optional[int] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completions_to_model(payload):
# if payload["n"] != 1:
# raise HTTPException(
# status_code=422,
# detail="Only one generation is supported. Please set candidate_count to 1.",
# )
params = {}
if temperature := payload.get("temperature"):
params["temperature"] = temperature
if max_tokens := payload.get("max_tokens"):
params["max_tokens"] = max_tokens
return {
"inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
"model": {"output_info": {"params": params}},
}
def process_response(
model,
prompt,
response,
model_response,
api_key,
data,
encoding,
logging_obj
):
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception:
raise ClarifaiError(
message=response.text, status_code=response.status_code, url=model
)
# print(completion_response)
try:
choices_list = []
for idx, item in enumerate(completion_response["outputs"]):
if len(item["data"]["text"]["raw"]) > 0:
message_obj = Message(content=item["data"]["text"]["raw"])
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason="stop",
index=idx + 1, #check
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model
)
# Calculate Usage
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
model_response["model"] = model
model_response["usage"] = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
def convert_model_to_url(model: str, api_base: str):
user_id, app_id, model_id = model.split(".")
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
def get_prompt_model_name(url: str):
clarifai_model_name = url.split("/")[-2]
if "claude" in clarifai_model_name:
return "anthropic", clarifai_model_name.replace("_", ".")
if ("llama" in clarifai_model_name)or ("mistral" in clarifai_model_name):
return "", "meta-llama/llama-2-chat"
else:
return "", clarifai_model_name
async def async_completion(
model: str,
prompt: str,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={}):
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return process_response(
model=model,
prompt=prompt,
response=response,
model_response=model_response,
api_key=api_key,
data=data,
encoding=encoding,
logging_obj=logging_obj,
)
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
custom_prompt_dict={},
acompletion=False,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
model = convert_model_to_url(model, api_base)
prompt = " ".join(message["content"] for message in messages) # TODO
## Load Config
config = litellm.ClarifaiConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
):
optional_params[k] = v
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
if custom_llm_provider == "anthropic":
prompt = prompt_factory(
model=orig_model_name,
messages=messages,
api_key=api_key,
custom_llm_provider="clarifai"
)
else:
prompt = prompt_factory(
model=orig_model_name,
messages=messages,
api_key=api_key,
custom_llm_provider=custom_llm_provider
)
# print(prompt); exit(0)
data = {
"prompt": prompt,
**optional_params,
}
data = completions_to_model(data)
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
if acompletion==True:
return async_completion(
model=model,
prompt=prompt,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
data=data,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
response = requests.post(
model,
headers=headers,
data=json.dumps(data),
)
# print(response.content); exit()
if response.status_code != 200:
raise ClarifaiError(status_code=response.status_code, message=response.text, url=model)
if "stream" in optional_params and optional_params["stream"] == True:
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="clarifai",
logging_obj=logging_obj,
)
return stream_response
else:
return process_response(
model=model,
prompt=prompt,
response=response,
model_response=model_response,
api_key=api_key,
data=data,
encoding=encoding,
logging_obj=logging_obj)
class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response
# Async iterator
def __aiter__(self):
return self
async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response

View file

@ -58,16 +58,25 @@ class AsyncHTTPHandler:
class HTTPHandler: class HTTPHandler:
def __init__( def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000 self,
timeout: Optional[httpx.Timeout] = None,
concurrent_limit=1000,
client: Optional[httpx.Client] = None,
): ):
# Create a client with a connection pool if timeout is None:
self.client = httpx.Client( timeout = _DEFAULT_TIMEOUT
timeout=timeout,
limits=httpx.Limits( if client is None:
max_connections=concurrent_limit, # Create a client with a connection pool
max_keepalive_connections=concurrent_limit, self.client = httpx.Client(
), timeout=timeout,
) limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
else:
self.client = client
def close(self): def close(self):
# Close the client when you're done with it # Close the client when you're done with it
@ -82,11 +91,15 @@ class HTTPHandler:
def post( def post(
self, self,
url: str, url: str,
data: Optional[dict] = None, data: Optional[Union[dict, str]] = None,
params: Optional[dict] = None, params: Optional[dict] = None,
headers: Optional[dict] = None, headers: Optional[dict] = None,
stream: bool = False,
): ):
response = self.client.post(url, data=data, params=params, headers=headers) req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response return response
def __del__(self) -> None: def __del__(self) -> None:

View file

@ -300,7 +300,7 @@ def get_ollama_response(
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
@ -484,7 +484,7 @@ async def ollama_acompletion(
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama_chat/" + data["model"] model_response["model"] = "ollama_chat/" + data["model"]

View file

@ -53,6 +53,113 @@ class OpenAIError(Exception):
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
return optional_params
class OpenAIConfig: class OpenAIConfig:
""" """
Reference: https://platform.openai.com/docs/api-reference/chat/create Reference: https://platform.openai.com/docs/api-reference/chat/create
@ -1327,8 +1434,8 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client, client=client,
) )
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data thread_id, **message_data # type: ignore
) )
response_obj: Optional[OpenAIMessage] = None response_obj: Optional[OpenAIMessage] = None
@ -1458,7 +1565,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client, client=client,
) )
response = openai_client.beta.threads.runs.create_and_poll( response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id, thread_id=thread_id,
assistant_id=assistant_id, assistant_id=assistant_id,
additional_instructions=additional_instructions, additional_instructions=additional_instructions,

View file

@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj: litellm.utils.Logging, logging_obj: litellm.utils.Logging,
optional_params: dict, optional_params: dict,
api_key: str, api_key: str,
data: dict, data: Union[dict, str],
messages: list, messages: list,
print_verbose, print_verbose,
encoding, encoding,
@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
try: try:
completion_response = response.json() completion_response = response.json()
except: except:
raise PredibaseError( raise PredibaseError(message=response.text, status_code=422)
message=response.text, status_code=response.status_code
)
if "error" in completion_response: if "error" in completion_response:
raise PredibaseError( raise PredibaseError(
message=str(completion_response["error"]), message=str(completion_response["error"]),
@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
}, },
) )
## COMPLETION CALL ## COMPLETION CALL
if acompletion is True: if acompletion == True:
### ASYNC STREAMING ### ASYNC STREAMING
if stream == True: if stream == True:
return self.async_streaming( return self.async_streaming(

View file

@ -1509,6 +1509,11 @@ def prompt_factory(
model="meta-llama/Meta-Llama-3-8B-Instruct", model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=messages, messages=messages,
) )
elif custom_llm_provider == "clarifai":
if "claude" in model:
return anthropic_pt(messages=messages)
elif custom_llm_provider == "perplexity": elif custom_llm_provider == "perplexity":
for message in messages: for message in messages:
message.pop("name", None) message.pop("name", None)

View file

@ -198,6 +198,23 @@ class VertexAIConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
"""
return [
"europe-central2",
"europe-north1",
"europe-southwest1",
"europe-west1",
"europe-west2",
"europe-west3",
"europe-west4",
"europe-west6",
"europe-west8",
"europe-west9",
]
import asyncio import asyncio
@ -850,6 +867,8 @@ async def async_completion(
Add support for acompletion calls for gemini-pro Add support for acompletion calls for gemini-pro
""" """
try: try:
import proto # type: ignore
if mode == "vision": if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call") print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
print_verbose(f"\nProcessing input messages = {messages}") print_verbose(f"\nProcessing input messages = {messages}")
@ -884,9 +903,21 @@ async def async_completion(
): ):
function_call = response.candidates[0].content.parts[0].function_call function_call = response.candidates[0].content.parts[0].function_call
args_dict = {} args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v # Check if it's a RepeatedComposite instance
args_str = json.dumps(args_dict) for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite
):
# If so, convert to list
args_dict[key] = [v for v in val]
else:
args_dict[key] = val
try:
args_str = json.dumps(args_dict)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[

View file

@ -1,12 +1,26 @@
from enum import Enum from enum import Enum
import json, types, time # noqa: E401 import json, types, time # noqa: E401
from contextlib import contextmanager from contextlib import asynccontextmanager, contextmanager
from typing import Callable, Dict, Optional, Any, Union, List from typing import (
Callable,
Dict,
Generator,
AsyncGenerator,
Iterator,
AsyncIterator,
Optional,
Any,
Union,
List,
ContextManager,
AsyncContextManager,
)
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
import litellm import litellm
from litellm.utils import ModelResponse, get_secret, Usage from litellm.utils import ModelResponse, Usage, get_secret
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM from .base import BaseLLM
from .prompt_templates import factory as ptf from .prompt_templates import factory as ptf
@ -149,6 +163,15 @@ class IBMWatsonXAIConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
"""
return [
"eu-de",
"eu-gb",
]
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# handle anthropic prompts and amazon titan prompts # handle anthropic prompts and amazon titan prompts
@ -188,11 +211,12 @@ class WatsonXAIEndpoint(str, Enum):
) )
EMBEDDINGS = "/ml/v1/text/embeddings" EMBEDDINGS = "/ml/v1/text/embeddings"
PROMPTS = "/ml/v1/prompts" PROMPTS = "/ml/v1/prompts"
AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
class IBMWatsonXAI(BaseLLM): class IBMWatsonXAI(BaseLLM):
""" """
Class to interface with IBM Watsonx.ai API for text generation and embeddings. Class to interface with IBM watsonx.ai API for text generation and embeddings.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai Reference: https://cloud.ibm.com/apidocs/watsonx-ai
""" """
@ -343,7 +367,7 @@ class IBMWatsonXAI(BaseLLM):
) )
if token is None and api_key is not None: if token is None and api_key is not None:
# generate the auth token # generate the auth token
if print_verbose: if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai") print_verbose("Generating IAM token for Watsonx.ai")
token = self.generate_iam_token(api_key) token = self.generate_iam_token(api_key)
elif token is None and api_key is None: elif token is None and api_key is None:
@ -378,10 +402,11 @@ class IBMWatsonXAI(BaseLLM):
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params: dict, optional_params=None,
litellm_params: Optional[dict] = None, acompletion=None,
litellm_params=None,
logger_fn=None, logger_fn=None,
timeout: Optional[float] = None, timeout=None,
): ):
""" """
Send a text generation request to the IBM Watsonx.ai API. Send a text generation request to the IBM Watsonx.ai API.
@ -402,12 +427,12 @@ class IBMWatsonXAI(BaseLLM):
model, messages, provider, custom_prompt_dict model, messages, provider, custom_prompt_dict
) )
def process_text_request(request_params: dict) -> ModelResponse: def process_text_gen_response(json_resp: dict) -> ModelResponse:
with self._manage_response( if "results" not in json_resp:
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout raise WatsonXAIError(
) as resp: status_code=500,
json_resp = resp.json() message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
)
generated_text = json_resp["results"][0]["generated_text"] generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"] prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"] completion_tokens = json_resp["results"][0]["generated_token_count"]
@ -415,36 +440,70 @@ class IBMWatsonXAI(BaseLLM):
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"] model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
setattr( usage = Usage(
model_response, prompt_tokens=prompt_tokens,
"usage", completion_tokens=completion_tokens,
Usage( total_tokens=prompt_tokens + completion_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
) )
setattr(model_response, "usage", usage)
return model_response return model_response
def process_stream_request( def process_stream_response(
request_params: dict, stream_resp: Union[Iterator[str], AsyncIterator],
) -> litellm.CustomStreamWrapper: ) -> litellm.CustomStreamWrapper:
streamwrapper = litellm.CustomStreamWrapper(
stream_resp,
model=model,
custom_llm_provider="watsonx",
logging_obj=logging_obj,
)
return streamwrapper
# create the function to manage the request to watsonx.ai
self.request_manager = RequestManager(logging_obj)
def handle_text_request(request_params: dict) -> ModelResponse:
with self.request_manager.request(
request_params,
input=prompt,
timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
async def handle_text_request_async(request_params: dict) -> ModelResponse:
async with self.request_manager.async_request(
request_params,
input=prompt,
timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled # stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self._manage_response( with self.request_manager.request(
request_params, request_params,
logging_obj=logging_obj,
stream=True, stream=True,
input=prompt, input=prompt,
timeout=timeout, timeout=timeout,
) as resp: ) as resp:
response = litellm.CustomStreamWrapper( streamwrapper = process_stream_response(resp.iter_lines())
resp.iter_lines(), return streamwrapper
model=model,
custom_llm_provider="watsonx", async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
logging_obj=logging_obj, # stream the response - generated chunks will be handled
) # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
return response async with self.request_manager.async_request(
request_params,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
streamwrapper = process_stream_response(resp.aiter_lines())
return streamwrapper
try: try:
## Get the response from the model ## Get the response from the model
@ -455,10 +514,18 @@ class IBMWatsonXAI(BaseLLM):
optional_params=optional_params, optional_params=optional_params,
print_verbose=print_verbose, print_verbose=print_verbose,
) )
if stream: if stream and (acompletion is True):
return process_stream_request(req_params) # stream and async text generation
return handle_stream_request_async(req_params)
elif stream:
# streaming text generation
return handle_stream_request(req_params)
elif (acompletion is True):
# async text generation
return handle_text_request_async(req_params)
else: else:
return process_text_request(req_params) # regular text generation
return handle_text_request(req_params)
except WatsonXAIError as e: except WatsonXAIError as e:
raise e raise e
except Exception as e: except Exception as e:
@ -473,6 +540,7 @@ class IBMWatsonXAI(BaseLLM):
model_response=None, model_response=None,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
aembedding=None,
): ):
""" """
Send a text embedding request to the IBM Watsonx.ai API. Send a text embedding request to the IBM Watsonx.ai API.
@ -507,9 +575,6 @@ class IBMWatsonXAI(BaseLLM):
} }
request_params = dict(version=api_params["api_version"]) request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
# request = httpx.Request(
# "POST", url, headers=headers, json=payload, params=request_params
# )
req_params = { req_params = {
"method": "POST", "method": "POST",
"url": url, "url": url,
@ -517,25 +582,49 @@ class IBMWatsonXAI(BaseLLM):
"json": payload, "json": payload,
"params": request_params, "params": request_params,
} }
with self._manage_response( request_manager = RequestManager(logging_obj)
req_params, logging_obj=logging_obj, input=input
) as resp:
json_resp = resp.json()
results = json_resp.get("results", []) def process_embedding_response(json_resp: dict) -> ModelResponse:
embedding_response = [] results = json_resp.get("results", [])
for idx, result in enumerate(results): embedding_response = []
embedding_response.append( for idx, result in enumerate(results):
{"object": "embedding", "index": idx, "embedding": result["embedding"]} embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": result["embedding"],
}
)
model_response["object"] = "list"
model_response["data"] = embedding_response
model_response["model"] = model
input_tokens = json_resp.get("input_token_count", 0)
model_response.usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
) )
model_response["object"] = "list" return model_response
model_response["data"] = embedding_response
model_response["model"] = model def handle_embedding(request_params: dict) -> ModelResponse:
input_tokens = json_resp.get("input_token_count", 0) with request_manager.request(request_params, input=input) as resp:
model_response.usage = Usage( json_resp = resp.json()
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens return process_embedding_response(json_resp)
)
return model_response async def handle_aembedding(request_params: dict) -> ModelResponse:
async with request_manager.async_request(request_params, input=input) as resp:
json_resp = resp.json()
return process_embedding_response(json_resp)
try:
if aembedding is True:
return handle_embedding(req_params)
else:
return handle_aembedding(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def generate_iam_token(self, api_key=None, **params): def generate_iam_token(self, api_key=None, **params):
headers = {} headers = {}
@ -558,52 +647,144 @@ class IBMWatsonXAI(BaseLLM):
self.token = iam_access_token self.token = iam_access_token
return iam_access_token return iam_access_token
@contextmanager def get_available_models(self, *, ids_only: bool = True, **params):
def _manage_response( api_params = self._get_api_params(params)
self, headers = {
request_params: dict, "Authorization": f"Bearer {api_params['token']}",
logging_obj: Any, "Content-Type": "application/json",
stream: bool = False, "Accept": "application/json",
input: Optional[Any] = None, }
timeout: Optional[float] = None, request_params = dict(version=api_params["api_version"])
): url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
request_str = ( req_params = dict(method="GET", url=url, headers=headers, params=request_params)
f"response = {request_params['method']}(\n" with RequestManager(logging_obj=None).request(req_params) as resp:
f"\turl={request_params['url']},\n" json_resp = resp.json()
f"\tjson={request_params['json']},\n" if not ids_only:
f")" return json_resp
) return [res["model_id"] for res in json_resp["resources"]]
logging_obj.pre_call(
input=input, class RequestManager:
api_key=request_params["headers"].get("Authorization"), """
additional_args={ Returns a context manager that manages the response from the request.
"complete_input_dict": request_params["json"], if async_ is True, returns an async context manager, otherwise returns a regular context manager.
"request_str": request_str,
}, Usage:
) ```python
if timeout: request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
request_params["timeout"] = timeout request_manager = RequestManager(logging_obj=logging_obj)
try: async with request_manager.request(request_params) as resp:
if stream: ...
resp = requests.request( # or
**request_params, with request_manager.async_request(request_params) as resp:
stream=True, ...
) ```
resp.raise_for_status() """
yield resp
else: def __init__(self, logging_obj=None):
resp = requests.request(**request_params) self.logging_obj = logging_obj
resp.raise_for_status()
yield resp def pre_call(
except Exception as e: self,
raise WatsonXAIError(status_code=500, message=str(e)) request_params: dict,
if not stream: input: Optional[Any] = None,
logging_obj.post_call( ):
if self.logging_obj is None:
return
request_str = (
f"response = {request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params.get('json')},\n"
f")"
)
self.logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params.get("json"),
"request_str": request_str,
},
)
def post_call(self, resp, request_params):
if self.logging_obj is None:
return
self.logging_obj.post_call(
input=input, input=input,
api_key=request_params["headers"].get("Authorization"), api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()), original_response=json.dumps(resp.json()),
additional_args={ additional_args={
"status_code": resp.status_code, "status_code": resp.status_code,
"complete_input_dict": request_params["json"], "complete_input_dict": request_params.get(
"data", request_params.get("json")
),
}, },
) )
@contextmanager
def request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
@asynccontextmanager
async def async_request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
)
# async_handler.client.verify = False
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp.status_code not in [200, 201]:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)

View file

@ -9,6 +9,7 @@
import os, openai, sys, json, inspect, uuid, datetime, threading import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union, BinaryIO from typing import Any, Literal, Union, BinaryIO
from typing_extensions import overload
from functools import partial from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy from copy import deepcopy
@ -56,6 +57,7 @@ from .llms import (
ollama, ollama,
ollama_chat, ollama_chat,
cloudflare, cloudflare,
clarifai,
cohere, cohere,
cohere_chat, cohere_chat,
petals, petals,
@ -75,6 +77,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.triton import TritonChatCompletion from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import ( from .llms.prompt_templates.factory import (
prompt_factory, prompt_factory,
@ -104,7 +107,6 @@ from litellm.utils import (
) )
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
openai_chat_completions = OpenAIChatCompletion() openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion() openai_text_completions = OpenAITextCompletion()
anthropic_chat_completions = AnthropicChatCompletion() anthropic_chat_completions = AnthropicChatCompletion()
@ -114,6 +116,7 @@ azure_text_completions = AzureTextCompletion()
huggingface = Huggingface() huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion() predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion() triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
####### COMPLETION ENDPOINTS ################ ####### COMPLETION ENDPOINTS ################
@ -256,7 +259,7 @@ async def acompletion(
- If `stream` is True, the function returns an async generator that yields completion lines. - If `stream` is True, the function returns an async generator that yields completion lines.
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
custom_llm_provider = None custom_llm_provider = kwargs.get("custom_llm_provider", None)
# Adjusted to use explicit arguments instead of *args and **kwargs # Adjusted to use explicit arguments instead of *args and **kwargs
completion_kwargs = { completion_kwargs = {
"model": model, "model": model,
@ -288,9 +291,10 @@ async def acompletion(
"model_list": model_list, "model_list": model_list,
"acompletion": True, # assuming this is a required parameter "acompletion": True, # assuming this is a required parameter
} }
_, custom_llm_provider, _, _ = get_llm_provider( if custom_llm_provider is None:
model=model, api_base=completion_kwargs.get("base_url", None) _, custom_llm_provider, _, _ = get_llm_provider(
) model=model, api_base=completion_kwargs.get("base_url", None)
)
try: try:
# Use a partial function to pass your keyword arguments # Use a partial function to pass your keyword arguments
func = partial(completion, **completion_kwargs, **kwargs) func = partial(completion, **completion_kwargs, **kwargs)
@ -299,9 +303,6 @@ async def acompletion(
ctx = contextvars.copy_context() ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func) func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
if ( if (
custom_llm_provider == "openai" custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
@ -323,6 +324,7 @@ async def acompletion(
or custom_llm_provider == "sagemaker" or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic" or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase" or custom_llm_provider == "predibase"
or (custom_llm_provider == "bedrock" and "cohere" in model)
or custom_llm_provider in litellm.openai_compatible_providers or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
@ -725,7 +727,6 @@ def completion(
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None: if input_cost_per_token is not None and output_cost_per_token is not None:
print_verbose(f"Registering model={model} in model cost map")
litellm.register_model( litellm.register_model(
{ {
f"{custom_llm_provider}/{model}": { f"{custom_llm_provider}/{model}": {
@ -847,6 +848,10 @@ def completion(
proxy_server_request=proxy_server_request, proxy_server_request=proxy_server_request,
preset_cache_key=preset_cache_key, preset_cache_key=preset_cache_key,
no_log=no_log, no_log=no_log,
input_cost_per_second=input_cost_per_second,
input_cost_per_token=input_cost_per_token,
output_cost_per_second=output_cost_per_second,
output_cost_per_token=output_cost_per_token,
) )
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -1212,6 +1217,61 @@ def completion(
) )
response = model_response response = model_response
elif (
"clarifai" in model
or custom_llm_provider == "clarifai"
or model in litellm.clarifai_models
):
clarifai_key = None
clarifai_key = (
api_key
or litellm.clarifai_key
or litellm.api_key
or get_secret("CLARIFAI_API_KEY")
or get_secret("CLARIFAI_API_TOKEN")
)
api_base = (
api_base
or litellm.api_base
or get_secret("CLARIFAI_API_BASE")
or "https://api.clarifai.com/v2"
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = clarifai.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
acompletion=acompletion,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=clarifai_key,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=model_response,
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=clarifai_key,
original_response=model_response,
)
response = model_response
elif custom_llm_provider == "anthropic": elif custom_llm_provider == "anthropic":
api_key = ( api_key = (
@ -1921,41 +1981,59 @@ def completion(
elif custom_llm_provider == "bedrock": elif custom_llm_provider == "bedrock":
# boto3 reads keys from .env # boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)
if ( if "cohere" in model:
"stream" in optional_params response = bedrock_chat_completion.completion(
and optional_params["stream"] == True model=model,
and not isinstance(response, CustomStreamWrapper) messages=messages,
): custom_prompt_dict=litellm.custom_prompt_dict,
# don't try to access stream object, model_response=model_response,
if "ai21" in model: print_verbose=print_verbose,
response = CustomStreamWrapper( optional_params=optional_params,
response, litellm_params=litellm_params,
model, logger_fn=logger_fn,
custom_llm_provider="bedrock", encoding=encoding,
logging_obj=logging, logging_obj=logging,
) extra_headers=extra_headers,
else: timeout=timeout,
response = CustomStreamWrapper( acompletion=acompletion,
iter(response), )
model, else:
custom_llm_provider="bedrock", response = bedrock.completion(
logging_obj=logging, model=model,
) messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
else:
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
if optional_params.get("stream", False): if optional_params.get("stream", False):
## LOGGING ## LOGGING

View file

@ -9,6 +9,30 @@
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true
}, },
"gpt-4o": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4-turbo-preview": { "gpt-4-turbo-preview": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 128000, "max_input_tokens": 128000,
@ -1571,6 +1595,135 @@
"litellm_provider": "replicate", "litellm_provider": "replicate",
"mode": "chat" "mode": "chat"
}, },
"openrouter/microsoft/wizardlm-2-8x22b:nitro": {
"max_tokens": 65536,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/google/gemini-pro-1.5": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0000025,
"output_cost_per_token": 0.0000075,
"input_cost_per_image": 0.00265,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/mistralai/mixtral-8x22b-instruct": {
"max_tokens": 65536,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000065,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/cohere/command-r-plus": {
"max_tokens": 128000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/databricks/dbrx-instruct": {
"max_tokens": 32768,
"input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000006,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/anthropic/claude-3-haiku": {
"max_tokens": 200000,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"input_cost_per_image": 0.0004,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/anthropic/claude-3-sonnet": {
"max_tokens": 200000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"input_cost_per_image": 0.0048,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/mistralai/mistral-large": {
"max_tokens": 32000,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/cognitivecomputations/dolphin-mixtral-8x7b": {
"max_tokens": 32769,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/google/gemini-pro-vision": {
"max_tokens": 45875,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000375,
"input_cost_per_image": 0.0025,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/fireworks/firellava-13b": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-8b-instruct:free": {
"max_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-8b-instruct:extended": {
"max_tokens": 16384,
"input_cost_per_token": 0.000000225,
"output_cost_per_token": 0.00000225,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct:nitro": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.00000059,
"output_cost_per_token": 0.00000079,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/openai/gpt-4-vision-preview": {
"max_tokens": 130000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"input_cost_per_image": 0.01445,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-3.5-turbo": { "openrouter/openai/gpt-3.5-turbo": {
"max_tokens": 4095, "max_tokens": 4095,
"input_cost_per_token": 0.0000015, "input_cost_per_token": 0.0000015,
@ -1621,14 +1774,14 @@
"tool_use_system_prompt_tokens": 395 "tool_use_system_prompt_tokens": 395
}, },
"openrouter/google/palm-2-chat-bison": { "openrouter/google/palm-2-chat-bison": {
"max_tokens": 8000, "max_tokens": 25804,
"input_cost_per_token": 0.0000005, "input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005, "output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat" "mode": "chat"
}, },
"openrouter/google/palm-2-codechat-bison": { "openrouter/google/palm-2-codechat-bison": {
"max_tokens": 8000, "max_tokens": 20070,
"input_cost_per_token": 0.0000005, "input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005, "output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
@ -1711,13 +1864,6 @@
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat" "mode": "chat"
}, },
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008,
"litellm_provider": "openrouter",
"mode": "chat"
},
"j2-ultra": { "j2-ultra": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,
@ -2522,6 +2668,24 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"cohere.command-r-plus-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000030,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"cohere.command-r-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"cohere.embed-english-v3": { "cohere.embed-english-v3": {
"max_tokens": 512, "max_tokens": 512,
"max_input_tokens": 512, "max_input_tokens": 512,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/a1602eb39f799143.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}(); !function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/f04e46b02318b660.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-5b257e1ab47d4b4a.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-dafd44dfa2da140c.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e49705773ae41779.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-5b257e1ab47d4b4a.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/a1602eb39f799143.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[25539,[\"936\",\"static/chunks/2f6dbc85-17d29013b8ff3da5.js\",\"566\",\"static/chunks/566-ccd699ab19124658.js\",\"931\",\"static/chunks/app/page-c804e862b63be987.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/a1602eb39f799143.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"K8KXTbmuI2ArWjjdMi2iq\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html> <!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[7926,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-6a39771cacf75ea6.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"obp5wqVSVDMiDTC414cR8\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@ -1,7 +1,7 @@
2:I[77831,[],""] 2:I[77831,[],""]
3:I[25539,["936","static/chunks/2f6dbc85-17d29013b8ff3da5.js","566","static/chunks/566-ccd699ab19124658.js","931","static/chunks/app/page-c804e862b63be987.js"],""] 3:I[7926,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-6a39771cacf75ea6.js"],""]
4:I[5613,[],""] 4:I[5613,[],""]
5:I[31778,[],""] 5:I[31778,[],""]
0:["K8KXTbmuI2ArWjjdMi2iq",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/a1602eb39f799143.css","precedence":"next","crossOrigin":""}]],"$L6"]]]] 0:["obp5wqVSVDMiDTC414cR8",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]] 6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null 1:null

View file

@ -1,33 +1,35 @@
model_list: model_list:
- litellm_params: - litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ api_base: os.environ/AZURE_API_BASE
api_key: my-fake-key api_key: os.environ/AZURE_API_KEY
model: openai/my-fake-model api_version: 2023-07-01-preview
model_name: fake-openai-endpoint model: azure/azure-embedding-model
- litellm_params: model_info:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ base_model: text-embedding-ada-002
api_key: my-fake-key-2 mode: embedding
model: openai/my-fake-model-2 model_name: text-embedding-ada-002
model_name: fake-openai-endpoint - model_name: gpt-3.5-turbo-012
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key-3
model: openai/my-fake-model-3
model_name: fake-openai-endpoint
- model_name: gpt-4
litellm_params: litellm_params:
model: gpt-3.5-turbo model: gpt-3.5-turbo
- litellm_params: api_base: http://0.0.0.0:8080
model: together_ai/codellama/CodeLlama-13b-Instruct-hf api_key: ""
model_name: CodeLlama-13b-Instruct - model_name: gpt-3.5-turbo-0125-preview
litellm_params:
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
input_cost_per_token: 0.0
output_cost_per_token: 0.0
router_settings: router_settings:
redis_host: redis redis_host: redis
# redis_password: <your redis password> # redis_password: <your redis password>
redis_port: 6379 redis_port: 6379
enable_pre_call_checks: true
litellm_settings: litellm_settings:
set_verbose: True set_verbose: True
fallbacks: [{"gpt-3.5-turbo-012": ["gpt-3.5-turbo-0125-preview"]}]
# service_callback: ["prometheus_system"] # service_callback: ["prometheus_system"]
# success_callback: ["prometheus"] # success_callback: ["prometheus"]
# failure_callback: ["prometheus"] # failure_callback: ["prometheus"]
@ -36,4 +38,5 @@ general_settings:
enable_jwt_auth: True enable_jwt_auth: True
disable_reset_budget: True disable_reset_budget: True
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds) proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle" routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
alerting: ["slack"]

View file

@ -1,11 +1,20 @@
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator from pydantic import ConfigDict, BaseModel, Field, root_validator, Json
from dataclasses import fields
import enum import enum
from typing import Optional, List, Union, Dict, Literal, Any from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime from datetime import datetime
import uuid, json, sys, os import uuid
import json
from litellm.types.router import UpdateRouterConfig from litellm.types.router import UpdateRouterConfig
try:
from pydantic import model_validator # pydantic v2
except ImportError:
from pydantic import root_validator # pydantic v1
def model_validator(mode):
pre = mode == "before"
return root_validator(pre=pre)
def hash_token(token: str): def hash_token(token: str):
import hashlib import hashlib
@ -35,8 +44,9 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1 # if using pydantic v1
return self.__fields_set__ return self.__fields_set__
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase): class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
@ -82,6 +92,7 @@ class LiteLLMRoutes(enum.Enum):
info_routes: List = [ info_routes: List = [
"/key/info", "/key/info",
"/team/info", "/team/info",
"/team/list",
"/user/info", "/user/info",
"/model/info", "/model/info",
"/v2/model/info", "/v2/model/info",
@ -110,6 +121,7 @@ class LiteLLMRoutes(enum.Enum):
"/team/new", "/team/new",
"/team/update", "/team/update",
"/team/delete", "/team/delete",
"/team/list",
"/team/info", "/team/info",
"/team/block", "/team/block",
"/team/unblock", "/team/unblock",
@ -182,8 +194,19 @@ class LiteLLM_JWTAuth(LiteLLMBase):
admin_jwt_scope: str = "litellm_proxy_admin" admin_jwt_scope: str = "litellm_proxy_admin"
admin_allowed_routes: List[ admin_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"] Literal[
] = ["management_routes"] "openai_routes",
"info_routes",
"management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
]
] = [
"management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
"info_routes",
]
team_jwt_scope: str = "litellm_team" team_jwt_scope: str = "litellm_team"
team_id_jwt_field: str = "client_id" team_id_jwt_field: str = "client_id"
team_allowed_routes: List[ team_allowed_routes: List[
@ -216,7 +239,7 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
llm_api_system_prompt: Optional[str] = None llm_api_system_prompt: Optional[str] = None
llm_api_fail_call_string: Optional[str] = None llm_api_fail_call_string: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_llm_api_params(cls, values): def check_llm_api_params(cls, values):
llm_api_check = values.get("llm_api_check") llm_api_check = values.get("llm_api_check")
if llm_api_check is True: if llm_api_check is True:
@ -274,8 +297,9 @@ class ProxyChatCompletionRequest(LiteLLMBase):
deployment_id: Optional[str] = None deployment_id: Optional[str] = None
request_timeout: Optional[int] = None request_timeout: Optional[int] = None
class Config: model_config = ConfigDict(
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs) extra = "allow", # allow params not defined here, these fall in litellm.completion(**kwargs)
)
class ModelInfoDelete(LiteLLMBase): class ModelInfoDelete(LiteLLMBase):
@ -302,11 +326,12 @@ class ModelInfo(LiteLLMBase):
] ]
] ]
class Config: model_config = ConfigDict(
extra = Extra.allow # Allow extra fields extra = "allow", # Allow extra fields
protected_namespaces = () protected_namespaces = (),
)
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("id") is None: if values.get("id") is None:
values.update({"id": str(uuid.uuid4())}) values.update({"id": str(uuid.uuid4())})
@ -332,10 +357,11 @@ class ModelParams(LiteLLMBase):
litellm_params: dict litellm_params: dict
model_info: ModelInfo model_info: ModelInfo
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("model_info") is None: if values.get("model_info") is None:
values.update({"model_info": ModelInfo()}) values.update({"model_info": ModelInfo()})
@ -371,8 +397,9 @@ class GenerateKeyRequest(GenerateRequestBase):
{} {}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class GenerateKeyResponse(GenerateKeyRequest): class GenerateKeyResponse(GenerateKeyRequest):
@ -382,7 +409,7 @@ class GenerateKeyResponse(GenerateKeyRequest):
user_id: Optional[str] = None user_id: Optional[str] = None
token_id: Optional[str] = None token_id: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("token") is not None: if values.get("token") is not None:
values.update({"key": values.get("token")}) values.update({"key": values.get("token")})
@ -422,8 +449,9 @@ class LiteLLM_ModelTable(LiteLLMBase):
created_by: str created_by: str
updated_by: str updated_by: str
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class NewUserRequest(GenerateKeyRequest): class NewUserRequest(GenerateKeyRequest):
@ -451,7 +479,7 @@ class UpdateUserRequest(GenerateRequestBase):
user_role: Optional[str] = None user_role: Optional[str] = None
max_budget: Optional[float] = None max_budget: Optional[float] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -471,7 +499,7 @@ class NewEndUserRequest(LiteLLMBase):
None # if no equivalent model in allowed region - default all requests to this model None # if no equivalent model in allowed region - default all requests to this model
) )
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("max_budget") is not None and values.get("budget_id") is not None: if values.get("max_budget") is not None and values.get("budget_id") is not None:
raise ValueError("Set either 'max_budget' or 'budget_id', not both.") raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
@ -484,7 +512,7 @@ class Member(LiteLLMBase):
user_id: Optional[str] = None user_id: Optional[str] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -509,8 +537,9 @@ class TeamBase(LiteLLMBase):
class NewTeamRequest(TeamBase): class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None model_aliases: Optional[dict] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class GlobalEndUsersSpend(LiteLLMBase): class GlobalEndUsersSpend(LiteLLMBase):
@ -529,7 +558,7 @@ class TeamMemberDeleteRequest(LiteLLMBase):
user_id: Optional[str] = None user_id: Optional[str] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -563,10 +592,11 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None model_id: Optional[int] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
dict_fields = [ dict_fields = [
"metadata", "metadata",
@ -602,8 +632,9 @@ class LiteLLM_BudgetTable(LiteLLMBase):
model_max_budget: Optional[dict] = None model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class NewOrganizationRequest(LiteLLM_BudgetTable): class NewOrganizationRequest(LiteLLM_BudgetTable):
@ -653,8 +684,9 @@ class KeyManagementSettings(LiteLLMBase):
class TeamDefaultSettings(LiteLLMBase): class TeamDefaultSettings(LiteLLMBase):
team_id: str team_id: str
class Config: model_config = ConfigDict(
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs) extra = "allow", # allow params not defined here, these fall in litellm.completion(**kwargs)
)
class DynamoDBArgs(LiteLLMBase): class DynamoDBArgs(LiteLLMBase):
@ -795,8 +827,9 @@ class ConfigYAML(LiteLLMBase):
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5", description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
) )
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_VerificationToken(LiteLLMBase): class LiteLLM_VerificationToken(LiteLLMBase):
@ -830,8 +863,9 @@ class LiteLLM_VerificationToken(LiteLLMBase):
user_id_rate_limits: Optional[dict] = None user_id_rate_limits: Optional[dict] = None
team_id_rate_limits: Optional[dict] = None team_id_rate_limits: Optional[dict] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
@ -861,7 +895,7 @@ class UserAPIKeyAuth(
user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
allowed_model_region: Optional[Literal["eu"]] = None allowed_model_region: Optional[Literal["eu"]] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_api_key(cls, values): def check_api_key(cls, values):
if values.get("api_key") is not None: if values.get("api_key") is not None:
values.update({"token": hash_token(values.get("api_key"))}) values.update({"token": hash_token(values.get("api_key"))})
@ -888,7 +922,7 @@ class LiteLLM_UserTable(LiteLLMBase):
tpm_limit: Optional[int] = None tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("spend") is None: if values.get("spend") is None:
values.update({"spend": 0.0}) values.update({"spend": 0.0})
@ -896,8 +930,9 @@ class LiteLLM_UserTable(LiteLLMBase):
values.update({"models": []}) values.update({"models": []})
return values return values
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_EndUserTable(LiteLLMBase): class LiteLLM_EndUserTable(LiteLLMBase):
@ -909,14 +944,15 @@ class LiteLLM_EndUserTable(LiteLLMBase):
default_model: Optional[str] = None default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("spend") is None: if values.get("spend") is None:
values.update({"spend": 0.0}) values.update({"spend": 0.0})
return values return values
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_SpendLogs(LiteLLMBase): class LiteLLM_SpendLogs(LiteLLMBase):

View file

@ -1,10 +1,7 @@
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
from fastapi import Request from fastapi import Request
from dotenv import load_dotenv
import os import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try: try:

View file

@ -0,0 +1,147 @@
from litellm.integrations.custom_logger import CustomLogger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
import litellm, traceback, sys, uuid
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from typing import Optional
class _PROXY_AzureContentSafety(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self, endpoint, api_key, thresholds=None):
try:
from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
from azure.ai.contentsafety.models import (
TextCategory,
AnalyzeTextOptions,
AnalyzeTextOutputType,
)
from azure.core.exceptions import HttpResponseError
except Exception as e:
raise Exception(
f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
)
self.endpoint = endpoint
self.api_key = api_key
self.text_category = TextCategory
self.analyze_text_options = AnalyzeTextOptions
self.analyze_text_output_type = AnalyzeTextOutputType
self.azure_http_error = HttpResponseError
self.thresholds = self._configure_thresholds(thresholds)
self.client = ContentSafetyClient(
self.endpoint, AzureKeyCredential(self.api_key)
)
def _configure_thresholds(self, thresholds=None):
default_thresholds = {
self.text_category.HATE: 4,
self.text_category.SELF_HARM: 4,
self.text_category.SEXUAL: 4,
self.text_category.VIOLENCE: 4,
}
if thresholds is None:
return default_thresholds
for key, default in default_thresholds.items():
if key not in thresholds:
thresholds[key] = default
return thresholds
def _compute_result(self, response):
result = {}
category_severity = {
item.category: item.severity for item in response.categories_analysis
}
for category in self.text_category:
severity = category_severity.get(category)
if severity is not None:
result[category] = {
"filtered": severity >= self.thresholds[category],
"severity": severity,
}
return result
async def test_violation(self, content: str, source: Optional[str] = None):
verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
# Construct a request
request = self.analyze_text_options(
text=content,
output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
)
# Analyze text
try:
response = await self.client.analyze_text(request)
except self.azure_http_error as e:
verbose_proxy_logger.debug(
"Error in Azure Content-Safety: %s", traceback.format_exc()
)
traceback.print_exc()
raise
result = self._compute_result(response)
verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
for key, value in result.items():
if value["filtered"]:
raise HTTPException(
status_code=400,
detail={
"error": "Violated content safety policy",
"source": source,
"category": key,
"severity": value["severity"],
},
)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
try:
if call_type == "completion" and "messages" in data:
for m in data["messages"]:
if "content" in m and isinstance(m["content"], str):
await self.test_violation(content=m["content"], source="input")
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response,
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices[0], litellm.utils.Choices
):
await self.test_violation(
content=response.choices[0].message.content, source="output"
)
# async def async_post_call_streaming_hook(
# self,
# user_api_key_dict: UserAPIKeyAuth,
# response: str,
# ):
# verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
# await self.test_violation(content=response, source="output")

View file

@ -4,6 +4,12 @@ model_list:
model: openai/fake model: openai/fake
api_key: fake-key api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/ api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: llama3
litellm_params:
model: groq/llama3-8b-8192
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
- model_name: "*" - model_name: "*"
litellm_params: litellm_params:
model: openai/* model: openai/*

View file

@ -425,7 +425,7 @@ async def user_api_key_auth(
litellm_proxy_roles=jwt_handler.litellm_jwtauth, litellm_proxy_roles=jwt_handler.litellm_jwtauth,
) )
if is_allowed == False: if is_allowed == False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes) actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception( raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
@ -2255,6 +2255,31 @@ class ProxyConfig:
batch_redis_obj = _PROXY_BatchRedisRequests() batch_redis_obj = _PROXY_BatchRedisRequests()
imported_list.append(batch_redis_obj) imported_list.append(batch_redis_obj)
elif (
isinstance(callback, str)
and callback == "azure_content_safety"
):
from litellm.proxy.hooks.azure_content_safety import (
_PROXY_AzureContentSafety,
)
azure_content_safety_params = litellm_settings[
"azure_content_safety_params"
]
for k, v in azure_content_safety_params.items():
if (
v is not None
and isinstance(v, str)
and v.startswith("os.environ/")
):
azure_content_safety_params[k] = (
litellm.get_secret(v)
)
azure_content_safety_obj = _PROXY_AzureContentSafety(
**azure_content_safety_params,
)
imported_list.append(azure_content_safety_obj)
else: else:
imported_list.append( imported_list.append(
get_instance_fn( get_instance_fn(
@ -3454,6 +3479,26 @@ async def startup_event():
await proxy_config.add_deployment( await proxy_config.add_deployment(
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
) )
if (
proxy_logging_obj is not None
and proxy_logging_obj.slack_alerting_instance is not None
and prisma_client is not None
):
print("Alerting: Initializing Weekly/Monthly Spend Reports") # noqa
### Schedule weekly/monhtly spend reports ###
scheduler.add_job(
proxy_logging_obj.slack_alerting_instance.send_weekly_spend_report,
"cron",
day_of_week="mon",
)
scheduler.add_job(
proxy_logging_obj.slack_alerting_instance.send_monthly_spend_report,
"cron",
day=1,
)
scheduler.start() scheduler.start()
@ -3639,7 +3684,7 @@ async def chat_completion(
### MODEL ALIAS MAPPING ### ### MODEL ALIAS MAPPING ###
# check if model name in model alias map # check if model name in model alias map
# get the actual model name # get the actual model name
if data["model"] in litellm.model_alias_map: if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]] data["model"] = litellm.model_alias_map[data["model"]]
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
@ -3673,6 +3718,10 @@ async def chat_completion(
# skip router if user passed their key # skip router if user passed their key
if "api_key" in data: if "api_key" in data:
tasks.append(litellm.acompletion(**data)) tasks.append(litellm.acompletion(**data))
elif "," in data["model"] and llm_router is not None:
_models_csv_string = data.pop("model")
_models = _models_csv_string.split(",")
tasks.append(llm_router.abatch_completion(models=_models, **data))
elif "user_config" in data: elif "user_config" in data:
# initialize a new router instance. make request using this Router # initialize a new router instance. make request using this Router
router_config = data.pop("user_config") router_config = data.pop("user_config")
@ -3733,6 +3782,7 @@ async def chat_completion(
"x-litellm-cache-key": cache_key, "x-litellm-cache-key": cache_key,
"x-litellm-model-api-base": api_base, "x-litellm-model-api-base": api_base,
"x-litellm-version": version, "x-litellm-version": version,
"x-litellm-model-region": user_api_key_dict.allowed_model_region or "",
} }
selected_data_generator = select_data_generator( selected_data_generator = select_data_generator(
response=response, response=response,
@ -3749,6 +3799,9 @@ async def chat_completion(
fastapi_response.headers["x-litellm-cache-key"] = cache_key fastapi_response.headers["x-litellm-cache-key"] = cache_key
fastapi_response.headers["x-litellm-model-api-base"] = api_base fastapi_response.headers["x-litellm-model-api-base"] = api_base
fastapi_response.headers["x-litellm-version"] = version fastapi_response.headers["x-litellm-version"] = version
fastapi_response.headers["x-litellm-model-region"] = (
user_api_key_dict.allowed_model_region or ""
)
### CALL HOOKS ### - modify outgoing data ### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook( response = await proxy_logging_obj.post_call_success_hook(
@ -4133,6 +4186,9 @@ async def embeddings(
fastapi_response.headers["x-litellm-cache-key"] = cache_key fastapi_response.headers["x-litellm-cache-key"] = cache_key
fastapi_response.headers["x-litellm-model-api-base"] = api_base fastapi_response.headers["x-litellm-model-api-base"] = api_base
fastapi_response.headers["x-litellm-version"] = version fastapi_response.headers["x-litellm-version"] = version
fastapi_response.headers["x-litellm-model-region"] = (
user_api_key_dict.allowed_model_region or ""
)
return response return response
except Exception as e: except Exception as e:
@ -4302,6 +4358,9 @@ async def image_generation(
fastapi_response.headers["x-litellm-cache-key"] = cache_key fastapi_response.headers["x-litellm-cache-key"] = cache_key
fastapi_response.headers["x-litellm-model-api-base"] = api_base fastapi_response.headers["x-litellm-model-api-base"] = api_base
fastapi_response.headers["x-litellm-version"] = version fastapi_response.headers["x-litellm-version"] = version
fastapi_response.headers["x-litellm-model-region"] = (
user_api_key_dict.allowed_model_region or ""
)
return response return response
except Exception as e: except Exception as e:
@ -4495,6 +4554,9 @@ async def audio_transcriptions(
fastapi_response.headers["x-litellm-cache-key"] = cache_key fastapi_response.headers["x-litellm-cache-key"] = cache_key
fastapi_response.headers["x-litellm-model-api-base"] = api_base fastapi_response.headers["x-litellm-model-api-base"] = api_base
fastapi_response.headers["x-litellm-version"] = version fastapi_response.headers["x-litellm-version"] = version
fastapi_response.headers["x-litellm-model-region"] = (
user_api_key_dict.allowed_model_region or ""
)
return response return response
except Exception as e: except Exception as e:
@ -4670,6 +4732,9 @@ async def moderations(
fastapi_response.headers["x-litellm-cache-key"] = cache_key fastapi_response.headers["x-litellm-cache-key"] = cache_key
fastapi_response.headers["x-litellm-model-api-base"] = api_base fastapi_response.headers["x-litellm-model-api-base"] = api_base
fastapi_response.headers["x-litellm-version"] = version fastapi_response.headers["x-litellm-version"] = version
fastapi_response.headers["x-litellm-model-region"] = (
user_api_key_dict.allowed_model_region or ""
)
return response return response
except Exception as e: except Exception as e:
@ -5319,6 +5384,141 @@ async def view_spend_tags(
) )
@router.get(
"/global/spend/report",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
responses={
200: {"model": List[LiteLLM_SpendLogs]},
},
)
async def get_global_spend_report(
start_date: Optional[str] = fastapi.Query(
default=None,
description="Time from which to start viewing spend",
),
end_date: Optional[str] = fastapi.Query(
default=None,
description="Time till which to view spend",
),
):
"""
Get Daily Spend per Team, based on specific startTime and endTime. Per team, view usage by each key, model
[
{
"group-by-day": "2024-05-10",
"teams": [
{
"team_name": "team-1"
"spend": 10,
"keys": [
"key": "1213",
"usage": {
"model-1": {
"cost": 12.50,
"input_tokens": 1000,
"output_tokens": 5000,
"requests": 100
},
"audio-modelname1": {
"cost": 25.50,
"seconds": 25,
"requests": 50
},
}
}
]
]
}
"""
if start_date is None or end_date is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Please provide start_date and end_date"},
)
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
# first get data from spend logs -> SpendByModelApiKey
# then read data from "SpendByModelApiKey" to format the response obj
sql_query = """
WITH SpendByModelApiKey AS (
SELECT
date_trunc('day', sl."startTime") AS group_by_day,
COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
sl.model,
sl.api_key,
SUM(sl.spend) AS model_api_spend,
SUM(sl.total_tokens) AS model_api_tokens
FROM
"LiteLLM_SpendLogs" sl
LEFT JOIN
"LiteLLM_TeamTable" tt
ON
sl.team_id = tt.team_id
WHERE
sl."startTime" BETWEEN $1::date AND $2::date
GROUP BY
date_trunc('day', sl."startTime"),
tt.team_alias,
sl.model,
sl.api_key
)
SELECT
group_by_day,
jsonb_agg(jsonb_build_object(
'team_name', team_name,
'total_spend', total_spend,
'metadata', metadata
)) AS teams
FROM (
SELECT
group_by_day,
team_name,
SUM(model_api_spend) AS total_spend,
jsonb_agg(jsonb_build_object(
'model', model,
'api_key', api_key,
'spend', model_api_spend,
'total_tokens', model_api_tokens
)) AS metadata
FROM
SpendByModelApiKey
GROUP BY
group_by_day,
team_name
) AS aggregated
GROUP BY
group_by_day
ORDER BY
group_by_day;
"""
db_response = await prisma_client.db.query_raw(
sql_query, start_date_obj, end_date_obj
)
if db_response is None:
return []
return db_response
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.get( @router.get(
"/global/spend/tags", "/global/spend/tags",
tags=["Budget & Spend Tracking"], tags=["Budget & Spend Tracking"],
@ -5363,6 +5563,13 @@ async def global_view_spend_tags(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
) )
if end_date is None or start_date is None:
raise ProxyException(
message="Please provide start_date and end_date",
type="bad_request",
param=None,
code=status.HTTP_400_BAD_REQUEST,
)
response = await ui_get_spend_by_tags( response = await ui_get_spend_by_tags(
start_date=start_date, end_date=end_date, prisma_client=prisma_client start_date=start_date, end_date=end_date, prisma_client=prisma_client
) )
@ -5386,6 +5593,55 @@ async def global_view_spend_tags(
) )
async def _get_spend_report_for_time_range(
start_date: str,
end_date: str,
):
global prisma_client
if prisma_client is None:
verbose_proxy_logger.error(
f"Database not connected. Connect a database to your proxy for weekly, monthly spend reports"
)
return None
try:
sql_query = """
SELECT
t.team_alias,
SUM(s.spend) AS total_spend
FROM
"LiteLLM_SpendLogs" s
LEFT JOIN
"LiteLLM_TeamTable" t ON s.team_id = t.team_id
WHERE
s."startTime"::DATE >= $1::date AND s."startTime"::DATE <= $2::date
GROUP BY
t.team_alias
ORDER BY
total_spend DESC;
"""
response = await prisma_client.db.query_raw(sql_query, start_date, end_date)
# get spend per tag for today
sql_query = """
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs"
WHERE "startTime"::DATE >= $1::date AND "startTime"::DATE <= $2::date
GROUP BY individual_request_tag
ORDER BY total_spend DESC;
"""
spend_per_tag = await prisma_client.db.query_raw(
sql_query, start_date, end_date
)
return response, spend_per_tag
except Exception as e:
verbose_proxy_logger.error("Exception in _get_daily_spend_reports", e) # noqa
@router.post( @router.post(
"/spend/calculate", "/spend/calculate",
tags=["Budget & Spend Tracking"], tags=["Budget & Spend Tracking"],
@ -5773,7 +6029,7 @@ async def global_spend_keys(
tags=["Budget & Spend Tracking"], tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def global_spend_per_tea(): async def global_spend_per_team():
""" """
[BETA] This is a beta endpoint. It will change. [BETA] This is a beta endpoint. It will change.
@ -9458,6 +9714,14 @@ async def health_services_endpoint(
level="Low", level="Low",
alert_type="budget_alerts", alert_type="budget_alerts",
) )
if prisma_client is not None:
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_monthly_spend_report()
)
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_weekly_spend_report()
)
return { return {
"status": "success", "status": "success",
"message": "Mock Slack Alert sent, verify Slack Alert Received on your channel", "message": "Mock Slack Alert sent, verify Slack Alert Received on your channel",

View file

@ -9,7 +9,8 @@
import copy, httpx import copy, httpx
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache from litellm.caching import RedisCache, InMemoryCache, DualCache
@ -46,6 +47,7 @@ from litellm.types.router import (
updateLiteLLMParams, updateLiteLLMParams,
RetryPolicy, RetryPolicy,
AlertingConfig, AlertingConfig,
DeploymentTypedDict,
) )
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.azure import get_azure_ad_token_from_oidc
@ -61,7 +63,7 @@ class Router:
def __init__( def __init__(
self, self,
model_list: Optional[list] = None, model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
## CACHING ## ## CACHING ##
redis_url: Optional[str] = None, redis_url: Optional[str] = None,
redis_host: Optional[str] = None, redis_host: Optional[str] = None,
@ -82,6 +84,9 @@ class Router:
default_max_parallel_requests: Optional[int] = None, default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False, set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO", debug_level: Literal["DEBUG", "INFO"] = "INFO",
default_fallbacks: Optional[
List[str]
] = None, # generic fallbacks, works across all deployments
fallbacks: List = [], fallbacks: List = [],
context_window_fallbacks: List = [], context_window_fallbacks: List = [],
model_group_alias: Optional[dict] = {}, model_group_alias: Optional[dict] = {},
@ -258,6 +263,11 @@ class Router:
self.retry_after = retry_after self.retry_after = retry_after
self.routing_strategy = routing_strategy self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks self.fallbacks = fallbacks or litellm.fallbacks
if default_fallbacks is not None:
if self.fallbacks is not None:
self.fallbacks.append({"*": default_fallbacks})
else:
self.fallbacks = [{"*": default_fallbacks}]
self.context_window_fallbacks = ( self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks context_window_fallbacks or litellm.context_window_fallbacks
) )
@ -469,12 +479,30 @@ class Router:
) )
raise e raise e
# fmt: off
@overload
async def acompletion( async def acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]: ) -> CustomStreamWrapper:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
# fmt: on
# The actual implementation of the function
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
):
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
@ -606,6 +634,33 @@ class Router:
self.fail_calls[model_name] += 1 self.fail_calls[model_name] += 1
raise e raise e
async def abatch_completion(
self, models: List[str], messages: List[Dict[str, str]], **kwargs
):
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
"""
try:
return await self.acompletion(model=model, messages=messages, **kwargs)
except Exception as e:
return e
_tasks = []
for model in models:
# add each task but if the task fails
_tasks.append(
_async_completion_no_exceptions(
model=model, messages=messages, **kwargs
)
)
response = await asyncio.gather(*_tasks)
return response
def image_generation(self, prompt: str, model: str, **kwargs): def image_generation(self, prompt: str, model: str, **kwargs):
try: try:
kwargs["model"] = model kwargs["model"] = model
@ -1386,7 +1441,7 @@ class Router:
verbose_router_logger.debug(f"Trying to fallback b/w models") verbose_router_logger.debug(f"Trying to fallback b/w models")
if ( if (
hasattr(e, "status_code") hasattr(e, "status_code")
and e.status_code == 400 and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError) and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request ): # don't retry a malformed request
raise e raise e
@ -1417,18 +1472,29 @@ class Router:
response = await self.async_function_with_retries( response = await self.async_function_with_retries(
*args, **kwargs *args, **kwargs
) )
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response return response
except Exception as e: except Exception as e:
pass pass
elif fallbacks is not None: elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
for item in fallbacks: generic_fallback_idx: Optional[int] = None
key_list = list(item.keys()) ## check for specific model group-specific fallbacks
if len(key_list) == 0: for idx, item in enumerate(fallbacks):
continue if list(item.keys())[0] == model_group:
if key_list[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None: if fallback_model_group is None:
verbose_router_logger.info( verbose_router_logger.info(
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
@ -1451,6 +1517,9 @@ class Router:
response = await self.async_function_with_fallbacks( response = await self.async_function_with_fallbacks(
*args, **kwargs *args, **kwargs
) )
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response return response
except Exception as e: except Exception as e:
raise e raise e
@ -1480,22 +1549,30 @@ class Router:
return response return response
except Exception as e: except Exception as e:
original_exception = e original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error """
if ( Retry Logic
isinstance(original_exception, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None """
) or ( _healthy_deployments = await self._async_get_healthy_deployments(
isinstance(original_exception, openai.RateLimitError) model=kwargs.get("model") or "",
and fallbacks is not None )
):
raise original_exception
### RETRY
_timeout = self._router_should_retry( # raises an exception if this error should not be retries
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception, e=original_exception,
remaining_retries=num_retries, remaining_retries=num_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
# sleeps for the length of the timeout
await asyncio.sleep(_timeout) await asyncio.sleep(_timeout)
if ( if (
@ -1529,10 +1606,14 @@ class Router:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
_timeout = self._router_should_retry( _healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
)
_timeout = self._time_to_sleep_before_retry(
e=original_exception, e=original_exception,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
await asyncio.sleep(_timeout) await asyncio.sleep(_timeout)
try: try:
@ -1541,17 +1622,57 @@ class Router:
pass pass
raise original_exception raise original_exception
def should_retry_this_error(
self,
error: Exception,
healthy_deployments: Optional[List] = None,
context_window_fallbacks: Optional[List] = None,
):
"""
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
2. raise an exception for RateLimitError if
- there are no fallbacks
- there are no healthy deployments in the same model group
"""
_num_healthy_deployments = 0
if healthy_deployments is not None and isinstance(healthy_deployments, list):
_num_healthy_deployments = len(healthy_deployments)
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(error, litellm.ContextWindowExceededError)
and context_window_fallbacks is None
):
raise error
# Error we should only retry if there are other deployments
if isinstance(error, openai.RateLimitError) or isinstance(
error, openai.AuthenticationError
):
if _num_healthy_deployments <= 0:
raise error
return True
def function_with_fallbacks(self, *args, **kwargs): def function_with_fallbacks(self, *args, **kwargs):
""" """
Try calling the function_with_retries Try calling the function_with_retries
If it fails after num_retries, fall back to another model group If it fails after num_retries, fall back to another model group
""" """
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
model_group = kwargs.get("model") model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks) fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get( context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks "context_window_fallbacks", self.context_window_fallbacks
) )
try: try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
)
response = self.function_with_retries(*args, **kwargs) response = self.function_with_retries(*args, **kwargs)
return response return response
except Exception as e: except Exception as e:
@ -1560,7 +1681,7 @@ class Router:
try: try:
if ( if (
hasattr(e, "status_code") hasattr(e, "status_code")
and e.status_code == 400 and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError) and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request ): # don't retry a malformed request
raise e raise e
@ -1602,10 +1723,20 @@ class Router:
elif fallbacks is not None: elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
fallback_model_group = None fallback_model_group = None
for item in fallbacks: generic_fallback_idx: Optional[int] = None
## check for specific model group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group: if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None: if fallback_model_group is None:
raise original_exception raise original_exception
@ -1629,12 +1760,27 @@ class Router:
raise e raise e
raise original_exception raise original_exception
def _router_should_retry( def _time_to_sleep_before_retry(
self, e: Exception, remaining_retries: int, num_retries: int self,
e: Exception,
remaining_retries: int,
num_retries: int,
healthy_deployments: Optional[List] = None,
) -> Union[int, float]: ) -> Union[int, float]:
""" """
Calculate back-off, then retry Calculate back-off, then retry
It should instantly retry only when:
1. there are healthy deployments in the same model group
2. there are fallbacks for the completion call
""" """
if (
healthy_deployments is not None
and isinstance(healthy_deployments, list)
and len(healthy_deployments) > 0
):
return 0
if hasattr(e, "response") and hasattr(e.response, "headers"): if hasattr(e, "response") and hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
@ -1671,23 +1817,29 @@ class Router:
except Exception as e: except Exception as e:
original_exception = e original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
if ( _healthy_deployments = self._get_healthy_deployments(
isinstance(original_exception, litellm.ContextWindowExceededError) model=kwargs.get("model"),
and context_window_fallbacks is not None )
) or (
isinstance(original_exception, openai.RateLimitError) # raises an exception if this error should not be retries
and fallbacks is not None self.should_retry_this_error(
): error=e,
raise original_exception healthy_deployments=_healthy_deployments,
## LOGGING context_window_fallbacks=context_window_fallbacks,
if num_retries > 0: )
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
### RETRY # decides how long to sleep before retry
_timeout = self._router_should_retry( _timeout = self._time_to_sleep_before_retry(
e=original_exception, e=original_exception,
remaining_retries=num_retries, remaining_retries=num_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
time.sleep(_timeout) time.sleep(_timeout)
for current_attempt in range(num_retries): for current_attempt in range(num_retries):
verbose_router_logger.debug( verbose_router_logger.debug(
@ -1701,11 +1853,15 @@ class Router:
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
_healthy_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
_timeout = self._router_should_retry( _timeout = self._time_to_sleep_before_retry(
e=e, e=e,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
time.sleep(_timeout) time.sleep(_timeout)
raise original_exception raise original_exception
@ -1908,6 +2064,47 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models return cooldown_models
def _get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = self._get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
async def _async_get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = await self._async_get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
def routing_strategy_pre_call_checks(self, deployment: dict): def routing_strategy_pre_call_checks(self, deployment: dict):
""" """
Mimics 'async_routing_strategy_pre_call_checks' Mimics 'async_routing_strategy_pre_call_checks'
@ -2339,7 +2536,7 @@ class Router:
) # cache for 1 hr ) # cache for 1 hr
else: else:
_api_key = api_key _api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str): if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key # only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15 _api_key = _api_key[:8] + "*" * 15
@ -2567,23 +2764,25 @@ class Router:
# init OpenAI, Azure clients # init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True)) self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model) # set region (if azure model) ## PREVIEW FEATURE ##
_auto_infer_region = os.environ.get("AUTO_INFER_REGION", False) if litellm.enable_preview_features == True:
if _auto_infer_region == True or _auto_infer_region == "True":
print("Auto inferring region") # noqa print("Auto inferring region") # noqa
""" """
Hiding behind a feature flag Hiding behind a feature flag
When there is a large amount of LLM deployments this makes startup times blow up When there is a large amount of LLM deployments this makes startup times blow up
""" """
try: try:
if "azure" in deployment.litellm_params.model: if (
"azure" in deployment.litellm_params.model
and deployment.litellm_params.region_name is None
):
region = litellm.utils.get_model_region( region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None litellm_params=deployment.litellm_params, mode=None
) )
deployment.litellm_params.region_name = region deployment.litellm_params.region_name = region
except Exception as e: except Exception as e:
verbose_router_logger.error( verbose_router_logger.debug(
"Unable to get the region for azure model - {}, {}".format( "Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e) deployment.litellm_params.model, str(e)
) )
@ -2961,7 +3160,7 @@ class Router:
): ):
# check if in allowed_model_region # check if in allowed_model_region
if ( if (
_is_region_eu(model_region=_litellm_params["region_name"]) _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
== False == False
): ):
invalid_model_indices.append(idx) invalid_model_indices.append(idx)
@ -3118,13 +3317,12 @@ class Router:
healthy_deployments.remove(deployment) healthy_deployments.remove(deployment)
# filter pre-call checks # filter pre-call checks
_allowed_model_region = (
request_kwargs.get("allowed_model_region")
if request_kwargs is not None
else None
)
if self.enable_pre_call_checks and messages is not None: if self.enable_pre_call_checks and messages is not None:
_allowed_model_region = (
request_kwargs.get("allowed_model_region")
if request_kwargs is not None
else None
)
if _allowed_model_region == "eu": if _allowed_model_region == "eu":
healthy_deployments = self._pre_call_checks( healthy_deployments = self._pre_call_checks(
model=model, model=model,
@ -3145,8 +3343,10 @@ class Router:
) )
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError( raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}" f"{RouterErrors.no_deployments_available.value}, passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
) )
if ( if (
@ -3506,7 +3706,7 @@ class Router:
) )
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_alert( proxy_logging_obj.slack_alerting_instance.send_alert(
message=f"Router: Cooling down deployment: {_api_base}, for {self.cooldown_time} seconds. Got exception: {str(exception_status)}", message=f"Router: Cooling down deployment: {_api_base}, for {self.cooldown_time} seconds. Got exception: {str(exception_status)}. Change 'cooldown_time' + 'allowed_failes' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
alert_type="cooldown_deployment", alert_type="cooldown_deployment",
level="Low", level="Low",
) )

View file

@ -8,8 +8,6 @@
import dotenv, os, requests, random # type: ignore import dotenv, os, requests, random # type: ignore
from typing import Optional from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger

View file

@ -1,12 +1,11 @@
#### What this does #### #### What this does ####
# picks based on response time (for streaming, this is time to first token) # picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
import dotenv, os, requests, random # type: ignore import os, requests, random # type: ignore
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random import random
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -102,9 +101,6 @@ class LowestCostLoggingHandler(CustomLogger):
if precise_minute not in request_count_dict[id]: if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {} request_count_dict[id][precise_minute] = {}
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM ## TPM
request_count_dict[id][precise_minute]["tpm"] = ( request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens

View file

@ -5,8 +5,6 @@ import dotenv, os, requests, random # type: ignore
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random import random
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -117,9 +115,6 @@ class LowestLatencyLoggingHandler(CustomLogger):
if precise_minute not in request_count_dict[id]: if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {} request_count_dict[id][precise_minute] = {}
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM ## TPM
request_count_dict[id][precise_minute]["tpm"] = ( request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens

View file

@ -4,8 +4,6 @@
import dotenv, os, requests, random import dotenv, os, requests, random
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
from datetime import datetime from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
from litellm import token_counter from litellm import token_counter
from litellm.caching import DualCache from litellm.caching import DualCache

View file

@ -5,8 +5,6 @@ import dotenv, os, requests, random
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
import datetime as datetime_og import datetime as datetime_og
from datetime import datetime from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback, asyncio, httpx import traceback, asyncio, httpx
import litellm import litellm
from litellm import token_counter from litellm import token_counter

View file

@ -228,6 +228,40 @@ async def test_langfuse_logging_without_request_response(stream, langfuse_client
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
@pytest.mark.asyncio
async def test_langfuse_masked_input_output(langfuse_client):
"""
Test that creates a trace with masked input and output
"""
import uuid
for mask_value in [True, False]:
_unique_trace_name = f"litellm-test-{str(uuid.uuid4())}"
litellm.set_verbose = True
litellm.success_callback = ["langfuse"]
response = await create_async_task(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "This is a test"}],
metadata={"trace_id": _unique_trace_name, "mask_input": mask_value, "mask_output": mask_value},
mock_response="This is a test response"
)
print(response)
expected_input = "redacted-by-litellm" if mask_value else {'messages': [{'content': 'This is a test', 'role': 'user'}]}
expected_output = "redacted-by-litellm" if mask_value else {'content': 'This is a test response', 'role': 'assistant'}
langfuse_client.flush()
await asyncio.sleep(2)
# get trace with _unique_trace_name
trace = langfuse_client.get_trace(id=_unique_trace_name)
generations = list(
reversed(langfuse_client.get_generations(trace_id=_unique_trace_name).data)
)
assert trace.input == expected_input
assert trace.output == expected_output
assert generations[0].input == expected_input
assert generations[0].output == expected_output
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_langfuse_logging_metadata(langfuse_client): async def test_langfuse_logging_metadata(langfuse_client):
""" """
@ -312,7 +346,7 @@ async def test_langfuse_logging_metadata(langfuse_client):
metadata["existing_trace_id"] = trace_id metadata["existing_trace_id"] = trace_id
langfuse_client.flush() langfuse_client.flush()
await asyncio.sleep(2) await asyncio.sleep(10)
# Tests the metadata filtering and the override of the output to be the last generation # Tests the metadata filtering and the override of the output to be the last generation
for trace_id, generation_ids in trace_identifiers.items(): for trace_id, generation_ids in trace_identifiers.items():
@ -339,6 +373,13 @@ async def test_langfuse_logging_metadata(langfuse_client):
for generation_id, generation in zip(generation_ids, generations): for generation_id, generation in zip(generation_ids, generations):
assert generation.id == generation_id assert generation.id == generation_id
assert generation.trace_id == trace_id assert generation.trace_id == trace_id
print(
"common keys in trace",
set(generation.metadata.keys()).intersection(
expected_filtered_metadata_keys
),
)
assert set(generation.metadata.keys()).isdisjoint( assert set(generation.metadata.keys()).isdisjoint(
expected_filtered_metadata_keys expected_filtered_metadata_keys
) )

View file

@ -590,19 +590,20 @@ def test_gemini_pro_vision_base64():
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
def test_gemini_pro_function_calling(): async def test_gemini_pro_function_calling(sync_mode):
try: try:
load_vertex_ai_credentials() load_vertex_ai_credentials()
response = litellm.completion( data = {
model="vertex_ai/gemini-pro", "model": "vertex_ai/gemini-pro",
messages=[ "messages": [
{ {
"role": "user", "role": "user",
"content": "Call the submit_cities function with San Francisco and New York", "content": "Call the submit_cities function with San Francisco and New York",
} }
], ],
tools=[ "tools": [
{ {
"type": "function", "type": "function",
"function": { "function": {
@ -618,11 +619,13 @@ def test_gemini_pro_function_calling():
}, },
} }
], ],
) }
if sync_mode:
response = litellm.completion(**data)
else:
response = await litellm.acompletion(**data)
print(f"response: {response}") print(f"response: {response}")
except litellm.APIError as e:
pass
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
pass pass
except Exception as e: except Exception as e:
@ -635,73 +638,66 @@ def test_gemini_pro_function_calling():
# gemini_pro_function_calling() # gemini_pro_function_calling()
@pytest.mark.parametrize("stream", [False, True])
@pytest.mark.parametrize("sync_mode", [False, True]) @pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gemini_pro_function_calling_streaming(stream, sync_mode): async def test_gemini_pro_function_calling_streaming(sync_mode):
load_vertex_ai_credentials() load_vertex_ai_credentials()
litellm.set_verbose = True litellm.set_verbose = True
tools = [ data = {
{ "model": "vertex_ai/gemini-pro",
"type": "function", "messages": [
"function": { {
"name": "get_current_weather", "role": "user",
"description": "Get the current weather in a given location", "content": "Call the submit_cities function with San Francisco and New York",
"parameters": { }
"type": "object", ],
"properties": { "tools": [
"location": { {
"type": "string", "type": "function",
"description": "The city and state, e.g. San Francisco, CA", "function": {
"name": "submit_cities",
"description": "Submits a list of cities",
"parameters": {
"type": "object",
"properties": {
"cities": {"type": "array", "items": {"type": "string"}}
}, },
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, "required": ["cities"],
}, },
"required": ["location"],
}, },
}, }
} ],
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
optional_params = {
"tools": tools,
"tool_choice": "auto", "tool_choice": "auto",
"n": 1, "n": 1,
"stream": stream, "stream": True,
"temperature": 0.1, "temperature": 0.1,
} }
chunks = []
try: try:
if sync_mode == True: if sync_mode == True:
response = litellm.completion( response = litellm.completion(**data)
model="gemini-pro", messages=messages, **optional_params
)
print(f"completion: {response}") print(f"completion: {response}")
if stream == True: for chunk in response:
# assert completion.choices[0].message.content is None chunks.append(chunk)
# assert len(completion.choices[0].message.tool_calls) == 1 assert isinstance(chunk, litellm.ModelResponse)
for chunk in response:
assert isinstance(chunk, litellm.ModelResponse)
else:
assert isinstance(response, litellm.ModelResponse)
else: else:
response = await litellm.acompletion( response = await litellm.acompletion(**data)
model="gemini-pro", messages=messages, **optional_params
)
print(f"completion: {response}") print(f"completion: {response}")
if stream == True: assert isinstance(response, litellm.CustomStreamWrapper)
# assert completion.choices[0].message.content is None
# assert len(completion.choices[0].message.tool_calls) == 1 async for chunk in response:
async for chunk in response: print(f"chunk: {chunk}")
print(f"chunk: {chunk}") chunks.append(chunk)
assert isinstance(chunk, litellm.ModelResponse) assert isinstance(chunk, litellm.ModelResponse)
else:
assert isinstance(response, litellm.ModelResponse) complete_response = litellm.stream_chunk_builder(chunks=chunks)
assert (
complete_response.choices[0].message.content is not None
or len(complete_response.choices[0].message.tool_calls) > 0
)
print(f"complete_response: {complete_response}")
except litellm.APIError as e: except litellm.APIError as e:
pass pass
except litellm.RateLimitError as e: except litellm.RateLimitError as e:

View file

@ -0,0 +1,290 @@
# What is this?
## Unit test for azure content safety
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from fastapi import HTTPException
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_strict_input_filtering_01():
"""
- have a response with a filtered input
- call the pre call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 2},
)
data = {
"messages": [
{"role": "system", "content": "You are an helpfull assistant"},
{"role": "user", "content": "Fuck yourself you stupid bitch"},
]
}
with pytest.raises(HTTPException) as exc_info:
await azure_content_safety.async_pre_call_hook(
user_api_key_dict=UserAPIKeyAuth(),
cache=DualCache(),
data=data,
call_type="completion",
)
assert exc_info.value.detail["source"] == "input"
assert exc_info.value.detail["category"] == "Hate"
assert exc_info.value.detail["severity"] == 2
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_strict_input_filtering_02():
"""
- have a response with a filtered input
- call the pre call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 2},
)
data = {
"messages": [
{"role": "system", "content": "You are an helpfull assistant"},
{"role": "user", "content": "Hello how are you ?"},
]
}
await azure_content_safety.async_pre_call_hook(
user_api_key_dict=UserAPIKeyAuth(),
cache=DualCache(),
data=data,
call_type="completion",
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_loose_input_filtering_01():
"""
- have a response with a filtered input
- call the pre call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 8},
)
data = {
"messages": [
{"role": "system", "content": "You are an helpfull assistant"},
{"role": "user", "content": "Fuck yourself you stupid bitch"},
]
}
await azure_content_safety.async_pre_call_hook(
user_api_key_dict=UserAPIKeyAuth(),
cache=DualCache(),
data=data,
call_type="completion",
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_loose_input_filtering_02():
"""
- have a response with a filtered input
- call the pre call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 8},
)
data = {
"messages": [
{"role": "system", "content": "You are an helpfull assistant"},
{"role": "user", "content": "Hello how are you ?"},
]
}
await azure_content_safety.async_pre_call_hook(
user_api_key_dict=UserAPIKeyAuth(),
cache=DualCache(),
data=data,
call_type="completion",
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_strict_output_filtering_01():
"""
- have a response with a filtered output
- call the post call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 2},
)
response = mock_completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
},
{
"role": "user",
"content": "Help me write a rap text song. Add some insults to make it more credible.",
},
],
mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
)
with pytest.raises(HTTPException) as exc_info:
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
)
assert exc_info.value.detail["source"] == "output"
assert exc_info.value.detail["category"] == "Hate"
assert exc_info.value.detail["severity"] == 2
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_strict_output_filtering_02():
"""
- have a response with a filtered output
- call the post call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 2},
)
response = mock_completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
},
{
"role": "user",
"content": "Help me write a rap text song. Add some insults to make it more credible.",
},
],
mock_response="I'm unable to help with you with hate speech",
)
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_loose_output_filtering_01():
"""
- have a response with a filtered output
- call the post call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 8},
)
response = mock_completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
},
{
"role": "user",
"content": "Help me write a rap text song. Add some insults to make it more credible.",
},
],
mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
)
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="beta feature - local testing is failing")
async def test_loose_output_filtering_02():
"""
- have a response with a filtered output
- call the post call hook
"""
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
azure_content_safety = _PROXY_AzureContentSafety(
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
thresholds={"Hate": 8},
)
response = mock_completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
},
{
"role": "user",
"content": "Help me write a rap text song. Add some insults to make it more credible.",
},
],
mock_response="I'm unable to help with you with hate speech",
)
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
)

View file

@ -26,7 +26,7 @@ model_list = [
} }
] ]
router = litellm.Router(model_list=model_list) router = litellm.Router(model_list=model_list) # type: ignore
async def _openai_completion(): async def _openai_completion():

View file

@ -206,7 +206,7 @@ def test_completion_bedrock_claude_sts_client_auth():
# test_completion_bedrock_claude_sts_client_auth() # test_completion_bedrock_claude_sts_client_auth()
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set") @pytest.mark.skip(reason="We don't have Circle CI OIDC credentials as yet")
def test_completion_bedrock_claude_sts_oidc_auth(): def test_completion_bedrock_claude_sts_oidc_auth():
print("\ncalling bedrock claude with oidc auth") print("\ncalling bedrock claude with oidc auth")
import os import os

View file

@ -0,0 +1,103 @@
import sys, os
import traceback
from dotenv import load_dotenv
import asyncio, logging
load_dotenv()
import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import (
embedding,
completion,
acompletion,
acreate,
completion_cost,
Timeout,
ModelResponse,
)
from litellm import RateLimitError
# litellm.num_retries = 3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
@pytest.fixture(autouse=True)
def reset_callbacks():
print("\npytest fixture - resetting callbacks")
litellm.success_callback = []
litellm._async_success_callback = []
litellm.failure_callback = []
litellm.callbacks = []
def test_completion_clarifai_claude_2_1():
print("calling clarifai claude completion")
import os
clarifai_pat = os.environ["CLARIFAI_API_KEY"]
try:
response = completion(
model="clarifai/anthropic.completion.claude-2_1",
messages=messages,
max_tokens=10,
temperature=0.1,
)
print(response)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occured: {e}")
def test_completion_clarifai_mistral_large():
try:
litellm.set_verbose = True
response: ModelResponse = completion(
model="clarifai/mistralai.completion.mistral-small",
messages=messages,
max_tokens=10,
temperature=0.78,
)
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
def test_async_completion_clarifai():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(
model="clarifai/openai.chat-completion.GPT-4",
messages=messages,
timeout=10,
api_key=os.getenv("CLARIFAI_API_KEY"),
)
print(f"response: {response}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_get_response())

View file

@ -68,6 +68,51 @@ def test_completion_custom_provider_model_name():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def _openai_mock_response(*args, **kwargs) -> litellm.ModelResponse:
_data = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": None,
"content": "\n\nHello there, how may I assist you today?",
},
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
}
return litellm.ModelResponse(**_data)
def test_null_role_response():
"""
Test if api returns 'null' role, 'assistant' role is still returned
"""
import openai
openai_client = openai.OpenAI()
with patch.object(
openai_client.chat.completions, "create", side_effect=_openai_mock_response
) as mock_response:
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
client=openai_client,
)
print(f"response: {response}")
assert response.id == "chatcmpl-123"
assert response.choices[0].message.role == "assistant"
def test_completion_azure_command_r(): def test_completion_azure_command_r():
try: try:
litellm.set_verbose = True litellm.set_verbose = True
@ -665,6 +710,7 @@ def test_completion_mistral_api():
"content": "Hey, how's it going?", "content": "Hey, how's it going?",
} }
], ],
seed=10,
) )
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)
@ -839,7 +885,7 @@ async def test_acompletion_claude2_1():
}, },
{"role": "user", "content": "Generate a 3 liner joke for me"}, {"role": "user", "content": "Generate a 3 liner joke for me"},
] ]
# test without max tokens # test without max-tokens
response = await litellm.acompletion(model="claude-2.1", messages=messages) response = await litellm.acompletion(model="claude-2.1", messages=messages)
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)
@ -1305,7 +1351,7 @@ def test_hf_classifier_task():
########################### End of Hugging Face Tests ############################################## ########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api(): # def test_completion_hf_api():
# # failing on circle ci commenting out # # failing on circle-ci commenting out
# try: # try:
# user_message = "write some code to find the sum of two numbers" # user_message = "write some code to find the sum of two numbers"
# messages = [{ "content": user_message,"role": "user"}] # messages = [{ "content": user_message,"role": "user"}]
@ -2584,6 +2630,69 @@ def test_completion_chat_sagemaker_mistral():
# test_completion_chat_sagemaker_mistral() # test_completion_chat_sagemaker_mistral()
def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.id, str)
assert response.id != ""
assert isinstance(response.object, str)
assert response.object != ""
assert isinstance(response.created, int)
assert isinstance(response.model, str)
assert response.model != ""
assert isinstance(response.choices, list)
assert len(response.choices) == 1
choice = response.choices[0]
assert isinstance(choice, litellm.Choices)
assert isinstance(choice.get("index"), int)
message = choice.get("message")
assert isinstance(message, litellm.Message)
assert isinstance(message.get("role"), str)
assert message.get("role") != ""
assert isinstance(message.get("content"), str)
assert message.get("content") != ""
assert choice.get("logprobs") is None
assert isinstance(choice.get("finish_reason"), str)
assert choice.get("finish_reason") != ""
assert isinstance(response.usage, litellm.Usage) # type: ignore
assert isinstance(response.usage.prompt_tokens, int) # type: ignore
assert isinstance(response.usage.completion_tokens, int) # type: ignore
assert isinstance(response.usage.total_tokens, int) # type: ignore
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_bedrock_command_r(sync_mode):
litellm.set_verbose = True
if sync_mode:
response = completion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
response_format_tests(response=response)
else:
response = await litellm.acompletion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
print(f"response: {response}")
response_format_tests(response=response)
print(f"response: {response}")
def test_completion_bedrock_titan_null_response(): def test_completion_bedrock_titan_null_response():
try: try:
response = completion( response = completion(
@ -3233,6 +3342,29 @@ def test_completion_watsonx():
print(response) print(response)
except litellm.APIError as e: except litellm.APIError as e:
pass pass
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_stream_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
try:
response = completion(
model=model_name,
messages=messages,
stop=["stop"],
max_tokens=20,
stream=True,
)
for chunk in response:
print(chunk)
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@ -3297,6 +3429,30 @@ async def test_acompletion_watsonx():
) )
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_stream_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
print("testing watsonx")
try:
response = await litellm.acompletion(
model=model_name,
messages=messages,
temperature=0.2,
max_tokens=80,
stream=True,
)
# Add any assertions here to check the response
async for chunk in response:
print(chunk)
except litellm.RateLimitError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -5,6 +5,7 @@ sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import time import time
from typing import Optional
import litellm import litellm
from litellm import ( from litellm import (
get_max_tokens, get_max_tokens,
@ -12,7 +13,56 @@ from litellm import (
open_ai_chat_completion_models, open_ai_chat_completion_models,
TranscriptionResponse, TranscriptionResponse,
) )
import pytest from litellm.utils import CustomLogger
import pytest, asyncio
class CustomLoggingHandler(CustomLogger):
response_cost: Optional[float] = None
def __init__(self):
super().__init__()
def log_success_event(self, kwargs, response_obj, start_time, end_time):
self.response_cost = kwargs["response_cost"]
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"kwargs - {kwargs}")
print(f"kwargs response cost - {kwargs.get('response_cost')}")
self.response_cost = kwargs["response_cost"]
print(f"response_cost: {self.response_cost} ")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_custom_pricing(sync_mode):
new_handler = CustomLoggingHandler()
litellm.callbacks = [new_handler]
if sync_mode:
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey!"}],
mock_response="What do you want?",
input_cost_per_token=0.0,
output_cost_per_token=0.0,
)
time.sleep(5)
else:
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey!"}],
mock_response="What do you want?",
input_cost_per_token=0.0,
output_cost_per_token=0.0,
)
await asyncio.sleep(5)
print(f"new_handler.response_cost: {new_handler.response_cost}")
assert new_handler.response_cost is not None
assert new_handler.response_cost == 0
def test_get_gpt3_tokens(): def test_get_gpt3_tokens():

View file

@ -5,6 +5,7 @@
import sys, os import sys, os
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import ConfigDict
load_dotenv() load_dotenv()
import os, io import os, io
@ -25,9 +26,7 @@ class DBModel(BaseModel):
model_name: str model_name: str
model_info: dict model_info: dict
litellm_params: dict litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
class Config:
protected_namespaces = ()
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -494,6 +494,8 @@ def test_watsonx_embeddings():
) )
print(f"response: {response}") print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage) assert isinstance(response.usage, litellm.Usage)
except litellm.RateLimitError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -37,14 +37,19 @@ def get_current_weather(location, unit="fahrenheit"):
# Example dummy function hard coded to return the same weather # Example dummy function hard coded to return the same weather
# In production, this could be your backend API or an external API # In production, this could be your backend API or an external API
def test_parallel_function_call(): @pytest.mark.parametrize(
"model", ["gpt-3.5-turbo-1106", "mistral/mistral-large-latest"]
)
def test_parallel_function_call(model):
try: try:
# Step 1: send the conversation and available functions to the model # Step 1: send the conversation and available functions to the model
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris?", "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
} }
] ]
tools = [ tools = [
@ -58,7 +63,7 @@ def test_parallel_function_call():
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "The city and state, e.g. San Francisco, CA", "description": "The city and state",
}, },
"unit": { "unit": {
"type": "string", "type": "string",
@ -71,7 +76,7 @@ def test_parallel_function_call():
} }
] ]
response = litellm.completion( response = litellm.completion(
model="gpt-3.5-turbo-1106", model=model,
messages=messages, messages=messages,
tools=tools, tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit tool_choice="auto", # auto is default, but we'll be explicit
@ -83,8 +88,8 @@ def test_parallel_function_call():
print("length of tool calls", len(tool_calls)) print("length of tool calls", len(tool_calls))
print("Expecting there to be 3 tool calls") print("Expecting there to be 3 tool calls")
assert ( assert (
len(tool_calls) > 1 len(tool_calls) > 0
) # this has to call the function for SF, Tokyo and parise ) # this has to call the function for SF, Tokyo and paris
# Step 2: check if the model wanted to call a function # Step 2: check if the model wanted to call a function
if tool_calls: if tool_calls:
@ -116,7 +121,7 @@ def test_parallel_function_call():
) # extend conversation with function response ) # extend conversation with function response
print(f"messages: {messages}") print(f"messages: {messages}")
second_response = litellm.completion( second_response = litellm.completion(
model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22 model=model, messages=messages, temperature=0.2, seed=22
) # get a new response from the model where it can see the function response ) # get a new response from the model where it can see the function response
print("second response\n", second_response) print("second response\n", second_response)
return second_response return second_response

View file

@ -109,7 +109,18 @@ def mock_patch_aimage_generation():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client_no_auth(): def fake_env_vars(monkeypatch):
# Set some fake environment variables
monkeypatch.setenv("OPENAI_API_KEY", "fake_openai_api_key")
monkeypatch.setenv("OPENAI_API_BASE", "http://fake-openai-api-base")
monkeypatch.setenv("AZURE_API_BASE", "http://fake-azure-api-base")
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake_azure_openai_api_key")
monkeypatch.setenv("AZURE_SWEDEN_API_BASE", "http://fake-azure-sweden-api-base")
monkeypatch.setenv("REDIS_HOST", "localhost")
@pytest.fixture(scope="function")
def client_no_auth(fake_env_vars):
# Assuming litellm.proxy.proxy_server is an object # Assuming litellm.proxy.proxy_server is an object
from litellm.proxy.proxy_server import cleanup_router_config_variables from litellm.proxy.proxy_server import cleanup_router_config_variables
@ -495,7 +506,18 @@ def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
from litellm.proxy.proxy_server import ProxyConfig from litellm.proxy.proxy_server import ProxyConfig
def test_load_router_config(): @mock.patch("litellm.proxy.proxy_server.litellm.Cache")
def test_load_router_config(mock_cache, fake_env_vars):
mock_cache.return_value.cache.__dict__ = {"redis_client": None}
mock_cache.return_value.supported_call_types = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
]
try: try:
import asyncio import asyncio
@ -557,6 +579,10 @@ def test_load_router_config():
litellm.disable_cache() litellm.disable_cache()
print("testing reading proxy config for cache with params") print("testing reading proxy config for cache with params")
mock_cache.return_value.supported_call_types = [
"embedding",
"aembedding",
]
asyncio.run( asyncio.run(
proxy_config.load_config( proxy_config.load_config(
router=None, router=None,

View file

@ -134,11 +134,13 @@ async def test_router_retries(sync_mode):
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
) )
else: else:
await router.acompletion( response = await router.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
) )
print(response.choices[0].message)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mistral_api_base", "mistral_api_base",
@ -687,6 +689,55 @@ def test_router_context_window_check_pre_call_check_out_group():
pytest.fail(f"Got unexpected exception on router! - {str(e)}") pytest.fail(f"Got unexpected exception on router! - {str(e)}")
@pytest.mark.parametrize("allowed_model_region", ["eu", None])
def test_router_region_pre_call_check(allowed_model_region):
"""
If region based routing set
- check if only model in allowed region is allowed by '_pre_call_checks'
"""
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"base_model": "azure/gpt-35-turbo",
"region_name": "eu",
},
"model_info": {"id": "1"},
},
{
"model_name": "gpt-3.5-turbo-large", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"model_info": {"id": "2"},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True)
_healthy_deployments = router._pre_call_checks(
model="gpt-3.5-turbo",
healthy_deployments=model_list,
messages=[{"role": "user", "content": "Hey!"}],
allowed_model_region=allowed_model_region,
)
if allowed_model_region is None:
assert len(_healthy_deployments) == 2
else:
assert len(_healthy_deployments) == 1, "No models selected as healthy"
assert (
_healthy_deployments[0]["model_info"]["id"] == "1"
), "Incorrect model id picked. Got id={}, expected id=1".format(
_healthy_deployments[0]["model_info"]["id"]
)
### FUNCTION CALLING ### FUNCTION CALLING

View file

@ -0,0 +1,60 @@
#### What this tests ####
# This tests litellm router with batch completion
import sys, os, time, openai
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import os, httpx
load_dotenv()
@pytest.mark.asyncio
async def test_batch_completion_multiple_models():
litellm.set_verbose = True
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
},
},
{
"model_name": "groq-llama",
"litellm_params": {
"model": "groq/llama3-8b-8192",
},
},
]
)
response = await router.abatch_completion(
models=["gpt-3.5-turbo", "groq-llama"],
messages=[
{"role": "user", "content": "is litellm becoming a better product ?"}
],
max_tokens=15,
)
print(response)
assert len(response) == 2
models_in_responses = []
for individual_response in response:
_model = individual_response["model"]
models_in_responses.append(_model)
# assert both models are different
assert models_in_responses[0] != models_in_responses[1]

View file

@ -83,9 +83,9 @@ def test_async_fallbacks(caplog):
# - error request, falling back notice, success notice # - error request, falling back notice, success notice
expected_logs = [ expected_logs = [
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m", "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m",
"litellm.acompletion(model=None)\x1b[31m Exception No deployments available for selected model, passed model=gpt-3.5-turbo\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo", "Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m", "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
"Successful fallback b/w models.",
] ]
# Assert that the captured logs match the expected log messages # Assert that the captured logs match the expected log messages

View file

@ -269,7 +269,7 @@ def test_sync_fallbacks_embeddings():
response = router.embedding(**kwargs) response = router.embedding(**kwargs)
print(f"customHandler.previous_models: {customHandler.previous_models}") print(f"customHandler.previous_models: {customHandler.previous_models}")
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback assert customHandler.previous_models == 1 # 1 init call, 2 retries, 1 fallback
router.reset() router.reset()
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
@ -323,7 +323,7 @@ async def test_async_fallbacks_embeddings():
await asyncio.sleep( await asyncio.sleep(
0.05 0.05
) # allow a delay as success_callbacks are on a separate thread ) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback assert customHandler.previous_models == 1 # 1 init call with a bad key
router.reset() router.reset()
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
@ -961,3 +961,96 @@ def test_custom_cooldown_times():
except Exception as e: except Exception as e:
print(e) print(e)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_service_unavailable_fallbacks(sync_mode):
"""
Initial model - openai
Fallback - azure
Error - 503, service unavailable
"""
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo-012",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "anything",
"api_base": "http://0.0.0.0:8080",
},
},
{
"model_name": "gpt-3.5-turbo-0125-preview",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
],
fallbacks=[{"gpt-3.5-turbo-012": ["gpt-3.5-turbo-0125-preview"]}],
)
if sync_mode:
response = router.completion(
model="gpt-3.5-turbo-012",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
else:
response = await router.acompletion(
model="gpt-3.5-turbo-012",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
assert response.model == "gpt-35-turbo"
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_default_model_fallbacks(sync_mode):
"""
Related issue - https://github.com/BerriAI/litellm/issues/3623
If model misconfigured, setup a default model for generic fallback
"""
router = Router(
model_list=[
{
"model_name": "bad-model",
"litellm_params": {
"model": "openai/my-bad-model",
"api_key": "my-bad-api-key",
},
},
{
"model_name": "my-good-model",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
],
default_fallbacks=["my-good-model"],
)
if sync_mode:
response = router.completion(
model="bad-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_testing_fallbacks=True,
mock_response="Hey! nice day",
)
else:
response = await router.acompletion(
model="bad-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_testing_fallbacks=True,
mock_response="Hey! nice day",
)
assert isinstance(response, litellm.ModelResponse)
assert response.model is not None and response.model == "gpt-4o"

Some files were not shown because too many files have changed in this diff Show more