Merge branch 'main' into litellm_end_user_obj

This commit is contained in:
Krish Dholakia 2024-05-16 14:16:09 -07:00 committed by GitHub
commit 0a775821db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
178 changed files with 11955 additions and 2346 deletions

View file

@ -58,6 +58,8 @@ jobs:
pip install python-multipart pip install python-multipart
pip install google-cloud-aiplatform pip install google-cloud-aiplatform
pip install prometheus-client==0.20.0 pip install prometheus-client==0.20.0
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
- save_cache: - save_cache:
paths: paths:
- ./venv - ./venv
@ -198,6 +200,7 @@ jobs:
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \ -e AWS_REGION_NAME=$AWS_REGION_NAME \
-e AUTO_INFER_REGION=True \
-e OPENAI_API_KEY=$OPENAI_API_KEY \ -e OPENAI_API_KEY=$OPENAI_API_KEY \
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \ -e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \ -e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \

10
.git-blame-ignore-revs Normal file
View file

@ -0,0 +1,10 @@
# Add the commit hash of any commit you want to ignore in `git blame` here.
# One commit hash per line.
#
# The GitHub Blame UI will use this file automatically!
#
# Run this command to always ignore formatting commits in `git blame`
# git config blame.ignoreRevsFile .git-blame-ignore-revs
# Update pydantic code to fix warnings (GH-3600)
876840e9957bc7e9f7d6a2b58c4d7c53dad16481

View file

@ -1,6 +1,3 @@
<!-- This is just examples. You can remove all items if you want. -->
<!-- Please remove all comments. -->
## Title ## Title
<!-- e.g. "Implement user authentication feature" --> <!-- e.g. "Implement user authentication feature" -->
@ -18,7 +15,6 @@
🐛 Bug Fix 🐛 Bug Fix
🧹 Refactoring 🧹 Refactoring
📖 Documentation 📖 Documentation
💻 Development Environment
🚄 Infrastructure 🚄 Infrastructure
✅ Test ✅ Test
@ -26,22 +22,8 @@
<!-- List of changes --> <!-- List of changes -->
## Testing ## [REQUIRED] Testing - Attach a screenshot of any new tests passing locall
If UI changes, send a screenshot/GIF of working UI fixes
<!-- Test procedure --> <!-- Test procedure -->
## Notes
<!-- Test results -->
<!-- Points to note for the reviewer, consultation content, concerns -->
## Pre-Submission Checklist (optional but appreciated):
- [ ] I have included relevant documentation updates (stored in /docs/my-website)
## OS Tests (optional but appreciated):
- [ ] Tested on Windows
- [ ] Tested on MacOS
- [ ] Tested on Linux

187
cookbook/liteLLM_clarifai_Demo.ipynb vendored Normal file
View file

@ -0,0 +1,187 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LiteLLM Clarifai \n",
"This notebook walks you through on how to use liteLLM integration of Clarifai and call LLM model from clarifai with response in openAI output format."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-Requisites"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#install necessary packages\n",
"!pip install litellm\n",
"!pip install clarifai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To obtain Clarifai Personal Access Token follow the steps mentioned in the [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"## Set Clarifai Credentials\n",
"import os\n",
"os.environ[\"CLARIFAI_API_KEY\"]= \"YOUR_CLARIFAI_PAT\" # Clarifai PAT"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mistral-large"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import litellm\n",
"\n",
"litellm.set_verbose=False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mistral large response : ModelResponse(id='chatcmpl-6eed494d-7ae2-4870-b9c2-6a64d50a6151', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\"In the grand tapestry of time, where tales unfold,\\nLies the chronicle of ages, a sight to behold.\\nA tale of empires rising, and kings of old,\\nOf civilizations lost, and stories untold.\\n\\nOnce upon a yesterday, in a time so vast,\\nHumans took their first steps, casting shadows in the past.\\nFrom the cradle of mankind, a journey they embarked,\\nThrough stone and bronze and iron, their skills they sharpened and marked.\\n\\nEgyptians built pyramids, reaching for the skies,\\nWhile Greeks sought wisdom, truth, in philosophies that lie.\\nRoman legions marched, their empire to expand,\\nAnd in the East, the Silk Road joined the world, hand in hand.\\n\\nThe Middle Ages came, with knights in shining armor,\\nFeudal lords and serfs, a time of both clamor and calm order.\\nThen Renaissance bloomed, like a flower in the sun,\\nA rebirth of art and science, a new age had begun.\\n\\nAcross the vast oceans, explorers sailed with courage bold,\\nDiscovering new lands, stories of adventure, untold.\\nIndustrial Revolution churned, progress in its wake,\\nMachines and factories, a whole new world to make.\\n\\nTwo World Wars raged, a testament to man's strife,\\nYet from the ashes rose hope, a renewed will for life.\\nInto the modern era, technology took flight,\\nConnecting every corner, bathed in digital light.\\n\\nHistory, a symphony, a melody of time,\\nA testament to human will, resilience so sublime.\\nIn every page, a lesson, in every tale, a guide,\\nFor understanding our past, shapes our future's tide.\", role='assistant'))], created=1713896412, model='https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=13, completion_tokens=338, total_tokens=351))\n"
]
}
],
"source": [
"from litellm import completion\n",
"\n",
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
"response=completion(\n",
" model=\"clarifai/mistralai.completion.mistral-large\",\n",
" messages=messages,\n",
" )\n",
"\n",
"print(f\"Mistral large response : {response}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Claude-2.1 "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Claude-2.1 response : ModelResponse(id='chatcmpl-d126c919-4db4-4aa3-ac8f-7edea41e0b93', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\" Here's a poem I wrote about history:\\n\\nThe Tides of Time\\n\\nThe tides of time ebb and flow,\\nCarrying stories of long ago.\\nFigures and events come into light,\\nShaping the future with all their might.\\n\\nKingdoms rise, empires fall, \\nLeaving traces that echo down every hall.\\nRevolutions bring change with a fiery glow,\\nToppling structures from long ago.\\n\\nExplorers traverse each ocean and land,\\nSeeking treasures they don't understand.\\nWhile artists and writers try to make their mark,\\nHoping their works shine bright in the dark.\\n\\nThe cycle repeats again and again,\\nAs humanity struggles to learn from its pain.\\nThough the players may change on history's stage,\\nThe themes stay the same from age to age.\\n\\nWar and peace, life and death,\\nLove and strife with every breath.\\nThe tides of time continue their dance,\\nAs we join in, by luck or by chance.\\n\\nSo we study the past to light the way forward, \\nHeeding warnings from stories told and heard.\\nThe future unfolds from this unending flow -\\nWhere the tides of time ultimately go.\", role='assistant'))], created=1713896579, model='https://api.clarifai.com/v2/users/anthropic/apps/completion/models/claude-2_1/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=12, completion_tokens=232, total_tokens=244))\n"
]
}
],
"source": [
"from litellm import completion\n",
"\n",
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
"response=completion(\n",
" model=\"clarifai/anthropic.completion.claude-2_1\",\n",
" messages=messages,\n",
" )\n",
"\n",
"print(f\"Claude-2.1 response : {response}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### OpenAI GPT-4 (Streaming)\n",
"Though clarifai doesn't support streaming, still you can call stream and get the response in standard StreamResponse format of liteLLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"In the quiet corners of time's grand hall,\\nLies the tale of rise and fall.\\nFrom ancient ruins to modern sprawl,\\nHistory, the greatest story of them all.\\n\\nEmpires have risen, empires have decayed,\\nThrough the eons, memories have stayed.\\nIn the book of time, history is laid,\\nA tapestry of events, meticulously displayed.\\n\\nThe pyramids of Egypt, standing tall,\\nThe Roman Empire's mighty sprawl.\\nFrom Alexander's conquest, to the Berlin Wall,\\nHistory, a silent witness to it all.\\n\\nIn the shadow of the past we tread,\\nWhere once kings and prophets led.\\nTheir stories in our hearts are spread,\\nEchoes of their words, in our minds are read.\\n\\nBattles fought and victories won,\\nActs of courage under the sun.\\nTales of love, of deeds done,\\nIn history's grand book, they all run.\\n\\nHeroes born, legends made,\\nIn the annals of time, they'll never fade.\\nTheir triumphs and failures all displayed,\\nIn the eternal march of history's parade.\\n\\nThe ink of the past is forever dry,\\nBut its lessons, we cannot deny.\\nIn its stories, truths lie,\\nIn its wisdom, we rely.\\n\\nHistory, a mirror to our past,\\nA guide for the future vast.\\nThrough its lens, we're ever cast,\\nIn the drama of life, forever vast.\", role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n",
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n"
]
}
],
"source": [
"from litellm import completion\n",
"\n",
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
"response = completion(\n",
" model=\"clarifai/openai.chat-completion.GPT-4\",\n",
" messages=messages,\n",
" stream=True,\n",
" api_key = \"c75cc032415e45368be331fdd2c06db0\")\n",
"\n",
"for chunk in response:\n",
" print(chunk)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# Caching - In-Memory, Redis, s3, Redis Semantic Cache # Caching - In-Memory, Redis, s3, Redis Semantic Cache, Disk
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py) [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)
@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht
::: :::
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache ## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
<Tabs> <Tabs>
@ -159,7 +159,7 @@ litellm.cache = Cache()
# Make completion calls # Make completion calls
response1 = completion( response1 = completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Tell me a joke."}] messages=[{"role": "user", "content": "Tell me a joke."}],
caching=True caching=True
) )
response2 = completion( response2 = completion(
@ -174,6 +174,43 @@ response2 = completion(
</TabItem> </TabItem>
<TabItem value="disk" label="disk cache">
### Quick Start
Install diskcache:
```shell
pip install diskcache
```
Then you can use the disk cache as follows.
```python
import litellm
from litellm import completion
from litellm.caching import Cache
litellm.cache = Cache(type="disk")
# Make completion calls
response1 = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Tell me a joke."}],
caching=True
)
response2 = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Tell me a joke."}],
caching=True
)
# response1 == response2, response 1 is cached
```
If you run the code two times, response1 will use the cache from the first run that was stored in a cache file.
</TabItem>
</Tabs> </Tabs>
@ -191,13 +228,13 @@ Advanced Params
```python ```python
litellm.enable_cache( litellm.enable_cache(
type: Optional[Literal["local", "redis"]] = "local", type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
supported_call_types: Optional[ supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]] List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
] = ["completion", "acompletion", "embedding", "aembedding"], ] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
**kwargs, **kwargs,
) )
``` ```
@ -215,13 +252,13 @@ Update the Cache params
```python ```python
litellm.update_cache( litellm.update_cache(
type: Optional[Literal["local", "redis"]] = "local", type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
supported_call_types: Optional[ supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]] List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
] = ["completion", "acompletion", "embedding", "aembedding"], ] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
**kwargs, **kwargs,
) )
``` ```
@ -276,22 +313,29 @@ cache.get_cache = get_cache
```python ```python
def __init__( def __init__(
self, self,
type: Optional[Literal["local", "redis", "s3"]] = "local", type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
supported_call_types: Optional[ supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]] List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
] = ["completion", "acompletion", "embedding", "aembedding"], # A list of litellm call types to cache for. Defaults to caching for all litellm call types. ] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
ttl: Optional[float] = None,
default_in_memory_ttl: Optional[float] = None,
# redis cache params # redis cache params
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
namespace: Optional[str] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None,
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
# s3 Bucket, boto3 configuration # s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None, s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None, s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None, s3_api_version: Optional[str] = None,
s3_path: Optional[str] = None, # if you wish to save to a spefic path s3_path: Optional[str] = None, # if you wish to save to a specific path
s3_use_ssl: Optional[bool] = True, s3_use_ssl: Optional[bool] = True,
s3_verify: Optional[Union[bool, str]] = None, s3_verify: Optional[Union[bool, str]] = None,
s3_endpoint_url: Optional[str] = None, s3_endpoint_url: Optional[str] = None,
@ -299,7 +343,11 @@ def __init__(
s3_aws_secret_access_key: Optional[str] = None, s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None, s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None, s3_config: Optional[Any] = None,
**kwargs,
# disk cache params
disk_cache_dir=None,
**kwargs
): ):
``` ```

View file

@ -40,7 +40,7 @@ cache = Cache()
cache.add_cache(cache_key="test-key", result="1234") cache.add_cache(cache_key="test-key", result="1234")
cache.get_cache(cache_key="test-key) cache.get_cache(cache_key="test-key")
``` ```
## Caching with Streaming ## Caching with Streaming

View file

@ -4,6 +4,12 @@ LiteLLM allows you to:
* Send 1 completion call to many models: Return Fastest Response * Send 1 completion call to many models: Return Fastest Response
* Send 1 completion call to many models: Return All Responses * Send 1 completion call to many models: Return All Responses
:::info
Trying to do batch completion on LiteLLM Proxy ? Go here: https://docs.litellm.ai/docs/proxy/user_keys#beta-batch-completions---pass-model-as-list
:::
## Send multiple completion calls to 1 model ## Send multiple completion calls to 1 model
In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call. In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.

View file

@ -37,11 +37,12 @@ print(response) # ["max_tokens", "tools", "tool_choice", "stream"]
This is a list of openai params we translate across providers. This is a list of openai params we translate across providers.
This list is constantly being updated. Use `litellm.get_supported_openai_params()` for an updated list of params for each model + provider
| Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers | | Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--| |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--|
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | |Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ | ✅ |
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | ✅ | ✅ | ✅ | ✅ |
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ | |OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ | |Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | |Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |

View file

@ -106,11 +106,12 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
## Custom mapping list ## Custom mapping list
Base case - we return the original exception. Base case - we return `litellm.APIConnectionError` exception (inherits from openai's APIConnectionError exception).
| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError | | custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------| |----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| watsonx | | | | | | | |✓| | | |
| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | | | openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |

View file

@ -0,0 +1,173 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Lago - Usage Based Billing
[Lago](https://www.getlago.com/) offers a self-hosted and cloud, metering and usage-based billing solution.
<Image img={require('../../img/lago.jpeg')} />
## Quick Start
Use just 1 lines of code, to instantly log your responses **across all providers** with Lago
Get your Lago [API Key](https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key)
```python
litellm.callbacks = ["lago"] # logs cost + usage of successful calls to lago
```
<Tabs>
<TabItem value="sdk" label="SDK">
```python
# pip install lago
import litellm
import os
os.environ["LAGO_API_BASE"] = "" # http://0.0.0.0:3000
os.environ["LAGO_API_KEY"] = ""
os.environ["LAGO_API_EVENT_CODE"] = "" # The billable metric's code - https://docs.getlago.com/guide/events/ingesting-usage#define-a-billable-metric
# LLM API Keys
os.environ['OPENAI_API_KEY']=""
# set lago as a callback, litellm will send the data to lago
litellm.success_callback = ["lago"]
# openai call
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Hi 👋 - i'm openai"}
],
user="your_customer_id" # 👈 SET YOUR CUSTOMER ID HERE
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add to Config.yaml
```yaml
model_list:
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key
model: openai/my-fake-model
model_name: fake-openai-endpoint
litellm_settings:
callbacks: ["lago"] # 👈 KEY CHANGE
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
<Tabs>
<TabItem value="curl" label="Curl">
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "fake-openai-endpoint",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"user": "your-customer-id" # 👈 SET YOUR CUSTOMER ID
}
'
```
</TabItem>
<TabItem value="openai_python" label="OpenAI Python SDK">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
], user="my_customer_id") # 👈 whatever your customer id is
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
import os
os.environ["OPENAI_API_KEY"] = "anything"
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000",
model = "gpt-3.5-turbo",
temperature=0.1,
extra_body={
"user": "my_customer_id" # 👈 whatever your customer id is
}
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
<Image img={require('../../img/lago_2.png')} />
## Advanced - Lagos Logging object
This is what LiteLLM will log to Lagos
```
{
"event": {
"transaction_id": "<generated_unique_id>",
"external_customer_id": <litellm_end_user_id>, # passed via `user` param in /chat/completion call - https://platform.openai.com/docs/api-reference/chat/create
"code": os.getenv("LAGO_API_EVENT_CODE"),
"properties": {
"input_tokens": <number>,
"output_tokens": <number>,
"model": <string>,
"response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
}
}
}
```

View file

@ -136,6 +136,7 @@ response = completion(
"existing_trace_id": "trace-id22", "existing_trace_id": "trace-id22",
"trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata "trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
"update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value "update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
"debug_langfuse": True, # Will log the exact metadata sent to litellm for the trace/generation as `metadata_passed_to_litellm`
}, },
) )
@ -147,7 +148,7 @@ print(response)
#### Trace Specific Parameters #### Trace Specific Parameters
* `trace_id` - Identifier for the trace, must use `existing_trace_id` instead or in conjunction with `trace_id` if this is an existing trace, auto-generated by default * `trace_id` - Identifier for the trace, must use `existing_trace_id` instead of `trace_id` if this is an existing trace, auto-generated by default
* `trace_name` - Name of the trace, auto-generated by default * `trace_name` - Name of the trace, auto-generated by default
* `session_id` - Session identifier for the trace, defaults to `None` * `session_id` - Session identifier for the trace, defaults to `None`
* `trace_version` - Version for the trace, defaults to value for `version` * `trace_version` - Version for the trace, defaults to value for `version`
@ -212,8 +213,20 @@ chat(messages)
## Redacting Messages, Response Content from Langfuse Logging ## Redacting Messages, Response Content from Langfuse Logging
### Redact Messages and Responses from all Langfuse Logging
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged. Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged.
### Redact Messages and Responses from specific Langfuse Logging
In the metadata typically passed for text completion or embedding calls you can set specific keys to mask the messages and responses for this call.
Setting `mask_input` to `True` will mask the input from being logged for this call
Setting `mask_output` to `True` will make the output from being logged for this call.
Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message.
## Troubleshooting & Errors ## Troubleshooting & Errors
### Data not getting logged to Langfuse ? ### Data not getting logged to Langfuse ?
- Ensure you're on the latest version of langfuse `pip install langfuse -U`. The latest version allows litellm to log JSON input/outputs to langfuse - Ensure you're on the latest version of langfuse `pip install langfuse -U`. The latest version allows litellm to log JSON input/outputs to langfuse

View file

@ -20,7 +20,7 @@ Use just 2 lines of code, to instantly log your responses **across all providers
Get your OpenMeter API Key from https://openmeter.cloud/meters Get your OpenMeter API Key from https://openmeter.cloud/meters
```python ```python
litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls to openmeter litellm.callbacks = ["openmeter"] # logs cost + usage of successful calls to openmeter
``` ```
@ -28,7 +28,7 @@ litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls
<TabItem value="sdk" label="SDK"> <TabItem value="sdk" label="SDK">
```python ```python
# pip install langfuse # pip install openmeter
import litellm import litellm
import os import os
@ -39,8 +39,8 @@ os.environ["OPENMETER_API_KEY"] = ""
# LLM API Keys # LLM API Keys
os.environ['OPENAI_API_KEY']="" os.environ['OPENAI_API_KEY']=""
# set langfuse as a callback, litellm will send the data to langfuse # set openmeter as a callback, litellm will send the data to openmeter
litellm.success_callback = ["openmeter"] litellm.callbacks = ["openmeter"]
# openai call # openai call
response = litellm.completion( response = litellm.completion(
@ -64,7 +64,7 @@ model_list:
model_name: fake-openai-endpoint model_name: fake-openai-endpoint
litellm_settings: litellm_settings:
success_callback: ["openmeter"] # 👈 KEY CHANGE callbacks: ["openmeter"] # 👈 KEY CHANGE
``` ```
2. Start Proxy 2. Start Proxy

View file

@ -0,0 +1,177 @@
# Clarifai
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.
## Pre-Requisites
`pip install clarifai`
`pip install litellm`
## Required Environment Variables
To obtain your Clarifai Personal access token follow this [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/). Optionally the PAT can also be passed in `completion` function.
```python
os.environ["CALRIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT
```
## Usage
```python
import os
from litellm import completion
os.environ["CLARIFAI_API_KEY"] = ""
response = completion(
model="clarifai/mistralai.completion.mistral-large",
messages=[{ "content": "Tell me a joke about physics?","role": "user"}]
)
```
**Output**
```json
{
"id": "chatcmpl-572701ee-9ab2-411c-ac75-46c1ba18e781",
"choices": [
{
"finish_reason": "stop",
"index": 1,
"message": {
"content": "Sure, here's a physics joke for you:\n\nWhy can't you trust an atom?\n\nBecause they make up everything!",
"role": "assistant"
}
}
],
"created": 1714410197,
"model": "https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs",
"object": "chat.completion",
"system_fingerprint": null,
"usage": {
"prompt_tokens": 14,
"completion_tokens": 24,
"total_tokens": 38
}
}
```
## Clarifai models
liteLLM supports non-streaming requests to all models on [Clarifai community](https://clarifai.com/explore/models?filterData=%5B%7B%22field%22%3A%22use_cases%22%2C%22value%22%3A%5B%22llm%22%5D%7D%5D&page=1&perPage=24)
Example Usage - Note: liteLLM supports all models deployed on Clarifai
## Llama LLMs
| Model Name | Function Call |
---------------------------|---------------------------------|
| clarifai/meta.Llama-2.llama2-7b-chat | `completion('clarifai/meta.Llama-2.llama2-7b-chat', messages)`
| clarifai/meta.Llama-2.llama2-13b-chat | `completion('clarifai/meta.Llama-2.llama2-13b-chat', messages)`
| clarifai/meta.Llama-2.llama2-70b-chat | `completion('clarifai/meta.Llama-2.llama2-70b-chat', messages)` |
| clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`|
| clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` |
## Mistal LLMs
| Model Name | Function Call |
|---------------------------------------------|------------------------------------------------------------------------|
| clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` |
| clarifai/mistralai.completion.mistral-large | `completion('clarifai/mistralai.completion.mistral-large', messages)` |
| clarifai/mistralai.completion.mistral-medium | `completion('clarifai/mistralai.completion.mistral-medium', messages)` |
| clarifai/mistralai.completion.mistral-small | `completion('clarifai/mistralai.completion.mistral-small', messages)` |
| clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1 | `completion('clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1', messages)`
| clarifai/mistralai.completion.mistral-7B-OpenOrca | `completion('clarifai/mistralai.completion.mistral-7B-OpenOrca', messages)` |
| clarifai/mistralai.completion.openHermes-2-mistral-7B | `completion('clarifai/mistralai.completion.openHermes-2-mistral-7B', messages)` |
## Jurassic LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/ai21.complete.Jurassic2-Grande | `completion('clarifai/ai21.complete.Jurassic2-Grande', messages)` |
| clarifai/ai21.complete.Jurassic2-Grande-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Grande-Instruct', messages)` |
| clarifai/ai21.complete.Jurassic2-Jumbo-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Jumbo-Instruct', messages)` |
| clarifai/ai21.complete.Jurassic2-Jumbo | `completion('clarifai/ai21.complete.Jurassic2-Jumbo', messages)` |
| clarifai/ai21.complete.Jurassic2-Large | `completion('clarifai/ai21.complete.Jurassic2-Large', messages)` |
## Wizard LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/wizardlm.generate.wizardCoder-Python-34B | `completion('clarifai/wizardlm.generate.wizardCoder-Python-34B', messages)` |
| clarifai/wizardlm.generate.wizardLM-70B | `completion('clarifai/wizardlm.generate.wizardLM-70B', messages)` |
| clarifai/wizardlm.generate.wizardLM-13B | `completion('clarifai/wizardlm.generate.wizardLM-13B', messages)` |
| clarifai/wizardlm.generate.wizardCoder-15B | `completion('clarifai/wizardlm.generate.wizardCoder-15B', messages)` |
## Anthropic models
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/anthropic.completion.claude-v1 | `completion('clarifai/anthropic.completion.claude-v1', messages)` |
| clarifai/anthropic.completion.claude-instant-1_2 | `completion('clarifai/anthropic.completion.claude-instant-1_2', messages)` |
| clarifai/anthropic.completion.claude-instant | `completion('clarifai/anthropic.completion.claude-instant', messages)` |
| clarifai/anthropic.completion.claude-v2 | `completion('clarifai/anthropic.completion.claude-v2', messages)` |
| clarifai/anthropic.completion.claude-2_1 | `completion('clarifai/anthropic.completion.claude-2_1', messages)` |
| clarifai/anthropic.completion.claude-3-opus | `completion('clarifai/anthropic.completion.claude-3-opus', messages)` |
| clarifai/anthropic.completion.claude-3-sonnet | `completion('clarifai/anthropic.completion.claude-3-sonnet', messages)` |
## OpenAI GPT LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/openai.chat-completion.GPT-4 | `completion('clarifai/openai.chat-completion.GPT-4', messages)` |
| clarifai/openai.chat-completion.GPT-3_5-turbo | `completion('clarifai/openai.chat-completion.GPT-3_5-turbo', messages)` |
| clarifai/openai.chat-completion.gpt-4-turbo | `completion('clarifai/openai.chat-completion.gpt-4-turbo', messages)` |
| clarifai/openai.completion.gpt-3_5-turbo-instruct | `completion('clarifai/openai.completion.gpt-3_5-turbo-instruct', messages)` |
## GCP LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/gcp.generate.gemini-1_5-pro | `completion('clarifai/gcp.generate.gemini-1_5-pro', messages)` |
| clarifai/gcp.generate.imagen-2 | `completion('clarifai/gcp.generate.imagen-2', messages)` |
| clarifai/gcp.generate.code-gecko | `completion('clarifai/gcp.generate.code-gecko', messages)` |
| clarifai/gcp.generate.code-bison | `completion('clarifai/gcp.generate.code-bison', messages)` |
| clarifai/gcp.generate.text-bison | `completion('clarifai/gcp.generate.text-bison', messages)` |
| clarifai/gcp.generate.gemma-2b-it | `completion('clarifai/gcp.generate.gemma-2b-it', messages)` |
| clarifai/gcp.generate.gemma-7b-it | `completion('clarifai/gcp.generate.gemma-7b-it', messages)` |
| clarifai/gcp.generate.gemini-pro | `completion('clarifai/gcp.generate.gemini-pro', messages)` |
| clarifai/gcp.generate.gemma-1_1-7b-it | `completion('clarifai/gcp.generate.gemma-1_1-7b-it', messages)` |
## Cohere LLMs
| Model Name | Function Call |
|-----------------------------------------------|---------------------------------------------------------------------|
| clarifai/cohere.generate.cohere-generate-command | `completion('clarifai/cohere.generate.cohere-generate-command', messages)` |
clarifai/cohere.generate.command-r-plus' | `completion('clarifai/clarifai/cohere.generate.command-r-plus', messages)`|
## Databricks LLMs
| Model Name | Function Call |
|---------------------------------------------------|---------------------------------------------------------------------|
| clarifai/databricks.drbx.dbrx-instruct | `completion('clarifai/databricks.drbx.dbrx-instruct', messages)` |
| clarifai/databricks.Dolly-v2.dolly-v2-12b | `completion('clarifai/databricks.Dolly-v2.dolly-v2-12b', messages)`|
## Microsoft LLMs
| Model Name | Function Call |
|---------------------------------------------------|---------------------------------------------------------------------|
| clarifai/microsoft.text-generation.phi-2 | `completion('clarifai/microsoft.text-generation.phi-2', messages)` |
| clarifai/microsoft.text-generation.phi-1_5 | `completion('clarifai/microsoft.text-generation.phi-1_5', messages)`|
## Salesforce models
| Model Name | Function Call |
|-----------------------------------------------------------|-------------------------------------------------------------------------------|
| clarifai/salesforce.blip.general-english-image-caption-blip-2 | `completion('clarifai/salesforce.blip.general-english-image-caption-blip-2', messages)` |
| clarifai/salesforce.xgen.xgen-7b-8k-instruct | `completion('clarifai/salesforce.xgen.xgen-7b-8k-instruct', messages)` |
## Other Top performing LLMs
| Model Name | Function Call |
|---------------------------------------------------|---------------------------------------------------------------------|
| clarifai/deci.decilm.deciLM-7B-instruct | `completion('clarifai/deci.decilm.deciLM-7B-instruct', messages)` |
| clarifai/upstage.solar.solar-10_7b-instruct | `completion('clarifai/upstage.solar.solar-10_7b-instruct', messages)` |
| clarifai/openchat.openchat.openchat-3_5-1210 | `completion('clarifai/openchat.openchat.openchat-3_5-1210', messages)` |
| clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B | `completion('clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B', messages)` |
| clarifai/fblgit.una-cybertron.una-cybertron-7b-v2 | `completion('clarifai/fblgit.una-cybertron.una-cybertron-7b-v2', messages)` |
| clarifai/tiiuae.falcon.falcon-40b-instruct | `completion('clarifai/tiiuae.falcon.falcon-40b-instruct', messages)` |
| clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat | `completion('clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat', messages)` |
| clarifai/bigcode.code.StarCoder | `completion('clarifai/bigcode.code.StarCoder', messages)` |
| clarifai/mosaicml.mpt.mpt-7b-instruct | `completion('clarifai/mosaicml.mpt.mpt-7b-instruct', messages)` |

View file

@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
<Tabs> <Tabs>
<TabItem value="tgi" label="Text-generation-interface (TGI)"> <TabItem value="tgi" label="Text-generation-interface (TGI)">
By default, LiteLLM will assume a huggingface call follows the TGI format.
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
import os import os
from litellm import completion from litellm import completion
@ -40,9 +45,58 @@ response = completion(
print(response) print(response)
``` ```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: wizard-coder
litellm_params:
model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "wizard-coder",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem> </TabItem>
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)"> <TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
Append `conversational` to the model name
e.g. `huggingface/conversational/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
import os import os
from litellm import completion from litellm import completion
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion( response = completion(
model="huggingface/facebook/blenderbot-400M-distill", model="huggingface/conversational/facebook/blenderbot-400M-distill",
messages=messages, messages=messages,
api_base="https://my-endpoint.huggingface.cloud" api_base="https://my-endpoint.huggingface.cloud"
) )
@ -62,7 +116,123 @@ response = completion(
print(response) print(response)
``` ```
</TabItem> </TabItem>
<TabItem value="none" label="Non TGI/Conversational-task LLMs"> <TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: blenderbot
litellm_params:
model: huggingface/conversational/facebook/blenderbot-400M-distill
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "blenderbot",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: bert-classifier
litellm_params:
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "bert-classifier",
"messages": [
{
"role": "user",
"content": "I like you!"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="none" label="Text Generation (NOT TGI)">
Append `text-generation` to the model name
e.g. `huggingface/text-generation/<model-name>`
```python ```python
import os import os
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints # e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion( response = completion(
model="huggingface/roneneldan/TinyStories-3M", model="huggingface/text-generation/roneneldan/TinyStories-3M",
messages=messages, messages=messages,
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud", api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
) )

View file

@ -102,12 +102,18 @@ Ollama supported models: https://github.com/ollama/ollama
| Model Name | Function Call | | Model Name | Function Call |
|----------------------|----------------------------------------------------------------------------------- |----------------------|-----------------------------------------------------------------------------------
| Mistral | `completion(model='ollama/mistral', messages, api_base="http://localhost:11434", stream=True)` | | Mistral | `completion(model='ollama/mistral', messages, api_base="http://localhost:11434", stream=True)` |
| Mistral-7B-Instruct-v0.1 | `completion(model='ollama/mistral-7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
| Mistral-7B-Instruct-v0.2 | `completion(model='ollama/mistral-7B-Instruct-v0.2', messages, api_base="http://localhost:11434", stream=False)` |
| Mixtral-8x7B-Instruct-v0.1 | `completion(model='ollama/mistral-8x7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
| Mixtral-8x22B-Instruct-v0.1 | `completion(model='ollama/mixtral-8x22B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
| Llama2 7B | `completion(model='ollama/llama2', messages, api_base="http://localhost:11434", stream=True)` | | Llama2 7B | `completion(model='ollama/llama2', messages, api_base="http://localhost:11434", stream=True)` |
| Llama2 13B | `completion(model='ollama/llama2:13b', messages, api_base="http://localhost:11434", stream=True)` | | Llama2 13B | `completion(model='ollama/llama2:13b', messages, api_base="http://localhost:11434", stream=True)` |
| Llama2 70B | `completion(model='ollama/llama2:70b', messages, api_base="http://localhost:11434", stream=True)` | | Llama2 70B | `completion(model='ollama/llama2:70b', messages, api_base="http://localhost:11434", stream=True)` |
| Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` | | Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` |
| Code Llama | `completion(model='ollama/codellama', messages, api_base="http://localhost:11434", stream=True)` | | Code Llama | `completion(model='ollama/codellama', messages, api_base="http://localhost:11434", stream=True)` |
| Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` | | Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` |
|Meta LLaMa3 8B | `completion(model='ollama/llama3', messages, api_base="http://localhost:11434", stream=False)` |
| Meta LLaMa3 70B | `completion(model='ollama/llama3:70b', messages, api_base="http://localhost:11434", stream=False)` |
| Orca Mini | `completion(model='ollama/orca-mini', messages, api_base="http://localhost:11434", stream=True)` | | Orca Mini | `completion(model='ollama/orca-mini', messages, api_base="http://localhost:11434", stream=True)` |
| Vicuna | `completion(model='ollama/vicuna', messages, api_base="http://localhost:11434", stream=True)` | | Vicuna | `completion(model='ollama/vicuna', messages, api_base="http://localhost:11434", stream=True)` |
| Nous-Hermes | `completion(model='ollama/nous-hermes', messages, api_base="http://localhost:11434", stream=True)` | | Nous-Hermes | `completion(model='ollama/nous-hermes', messages, api_base="http://localhost:11434", stream=True)` |

View file

@ -20,7 +20,7 @@ os.environ["OPENAI_API_KEY"] = "your-api-key"
# openai call # openai call
response = completion( response = completion(
model = "gpt-3.5-turbo", model = "gpt-4o",
messages=[{ "content": "Hello, how are you?","role": "user"}] messages=[{ "content": "Hello, how are you?","role": "user"}]
) )
``` ```
@ -163,6 +163,8 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
| Model Name | Function Call | | Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------| |-----------------------|-----------------------------------------------------------------|
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
| gpt-4o-2024-05-13 | `response = completion(model="gpt-4o-2024-05-13", messages=messages)` |
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` | | gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` | | gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` | | gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
@ -186,6 +188,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
## OpenAI Vision Models ## OpenAI Vision Models
| Model Name | Function Call | | Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------| |-----------------------|-----------------------------------------------------------------|
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` | | gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` | | gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |

View file

@ -0,0 +1,95 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Triton Inference Server
LiteLLM supports Embedding Models on Triton Inference Servers
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### Example Call
Use the `triton/` prefix to route to triton server
```python
from litellm import embedding
import os
response = await litellm.aembedding(
model="triton/<your-triton-model>",
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
input=["good morning from litellm"],
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: my-triton-model
litellm_params:
model: triton/<your-triton-model>"
api_base: https://your-triton-api-base/triton/embeddings
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
from openai import OpenAI
# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
input=["hello from litellm"],
model="my-triton-model"
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
```shell
curl --location 'http://0.0.0.0:4000/embeddings' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
"model": "my-triton-model",
"input": ["write a litellm poem"]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>

View file

@ -364,6 +364,8 @@ response = completion(
| Model Name | Function Call | | Model Name | Function Call |
|------------------|--------------------------------------| |------------------|--------------------------------------|
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` | | gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
| gemini-1.5-flash-preview-0514 | `completion('gemini-1.5-flash-preview-0514', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
| gemini-1.5-pro-preview-0514 | `completion('gemini-1.5-pro-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-pro-preview-0514', messages)` |

View file

@ -1,13 +1,18 @@
# 🚨 Alerting # 🚨 Alerting
Get alerts for: Get alerts for:
- Hanging LLM api calls - Hanging LLM api calls
- Failed LLM api calls
- Slow LLM api calls - Slow LLM api calls
- Budget Tracking per key/user: - Failed LLM api calls
- When a User/Key crosses their Budget - Budget Tracking per key/user
- When a User/Key is 15% away from crossing their Budget - Spend Reports - Weekly & Monthly spend per Team, Tag
- Failed db read/writes - Failed db read/writes
- Daily Reports:
- **LLM** Top 5 slowest deployments
- **LLM** Top 5 deployments with most failed requests
- **Spend** Weekly & Monthly spend per Team, Tag
## Quick Start ## Quick Start
@ -17,10 +22,12 @@ Set up a slack alert channel to receive alerts from proxy.
Get a slack webhook url from https://api.slack.com/messaging/webhooks Get a slack webhook url from https://api.slack.com/messaging/webhooks
You can also use Discord Webhooks, see [here](#using-discord-webhooks)
### Step 2: Update config.yaml ### Step 2: Update config.yaml
Let's save a bad key to our proxy - Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts.
- Just for testing purposes, let's save a bad key to our proxy.
```yaml ```yaml
model_list: model_list:
@ -33,16 +40,59 @@ general_settings:
alerting: ["slack"] alerting: ["slack"]
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+ alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
environment_variables:
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
SLACK_DAILY_REPORT_FREQUENCY: "86400" # 24 hours; Optional: defaults to 12 hours
``` ```
Set `SLACK_WEBHOOK_URL` in your proxy env
```shell
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
```
### Step 3: Start proxy ### Step 3: Start proxy
```bash ```bash
$ litellm --config /path/to/config.yaml $ litellm --config /path/to/config.yaml
``` ```
## Testing Alerting is Setup Correctly
Make a GET request to `/health/services`, expect to see a test slack alert in your provided webhook slack channel
```shell
curl -X GET 'http://localhost:4000/health/services?service=slack' \
-H 'Authorization: Bearer sk-1234'
```
## Extras
### Using Discord Webhooks
Discord provides a slack compatible webhook url that you can use for alerting
##### Quick Start
1. Get a webhook url for your discord channel
2. Append `/slack` to your discord webhook - it should look like
```
"https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack"
```
3. Add it to your litellm config
```yaml
model_list:
model_name: "azure-model"
litellm_params:
model: "azure/gpt-35-turbo"
api_key: "my-bad-key" # 👈 bad key
general_settings:
alerting: ["slack"]
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
environment_variables:
SLACK_WEBHOOK_URL: "https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack"
```
That's it ! You're ready to go !

View file

@ -0,0 +1,229 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 💰 Billing
Bill users for their usage.
Requirements:
- Setup a billing plan on Lago, for usage-based billing. We recommend following their Stripe tutorial - https://docs.getlago.com/templates/per-transaction/stripe#step-1-create-billable-metrics-for-transaction
Steps:
- Connect the proxy to Lago
- Set the id you want to bill for (customers, internal users, teams)
- Start!
## 1. Connect proxy to Lago
Add your Lago keys to the environment
```bash
export LAGO_API_BASE="http://localhost:3000" # self-host - https://docs.getlago.com/guide/self-hosted/docker#run-the-app
export LAGO_API_KEY="3e29d607-de54-49aa-a019-ecf585729070" # Get key - https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key
export LAGO_API_EVENT_CODE="openai_tokens" # name of lago billing code
```
Set 'lago' as a callback on your proxy config.yaml
```yaml
...
litellm_settings:
callbacks: ["lago"]
```
## 2. Set the id you want to bill for
For:
- Customers (id passed via 'user' param in /chat/completion call) = 'end_user_id'
- Internal Users (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'user_id'
- Teams (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'team_id'
```yaml
export LAGO_API_CHARGE_BY="end_user_id" # 👈 Charge 'Customers'. Default is 'end_user_id'.
```
## 3. Start billing!
<Tabs>
<TabItem value="customers" label="Customer Billing">
### **Curl**
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"user": "my_customer_id" # 👈 whatever your customer id is
}
'
```
### **OpenAI Python SDK**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
], user="my_customer_id") # 👈 whatever your customer id is
print(response)
```
### **Langchain**
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
import os
os.environ["OPENAI_API_KEY"] = "anything"
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000",
model = "gpt-3.5-turbo",
temperature=0.1,
extra_body={
"user": "my_customer_id" # 👈 whatever your customer id is
}
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
<TabItem value="internal_user" label="Internal User (Key Owner) Billing">
1. Create a key for that user
```bash
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{"user_id": "my-unique-id"}'
```
Response Object:
```bash
{
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
}
```
2. Make API Calls with that Key
```python
import openai
client = openai.OpenAI(
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Generated key
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="teams" label="Team Billing">
1. Create a key for that team
```bash
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{"team_id": "my-unique-id"}'
```
Response Object:
```bash
{
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
}
```
2. Make API Calls with that Key
```python
import openai
client = openai.OpenAI(
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Generated key
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
</Tabs>
**See Results on Lago**
<Image img={require('../../img/lago_2.png')} style={{ width: '500px', height: 'auto' }} />
## Advanced - Lago Logging object
This is what LiteLLM will log to Lagos
```
{
"event": {
"transaction_id": "<generated_unique_id>",
"external_customer_id": <selected_id>, # either 'end_user_id', 'user_id', or 'team_id'. Default 'end_user_id'.
"code": os.getenv("LAGO_API_EVENT_CODE"),
"properties": {
"input_tokens": <number>,
"output_tokens": <number>,
"model": <string>,
"response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
}
}
}
```

View file

@ -1,8 +1,161 @@
# Cost Tracking - Azure import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 💸 Spend Tracking
Track spend for keys, users, and teams across 100+ LLMs.
## Getting Spend Reports - To Charge Other Teams, API Keys
Use the `/global/spend/report` endpoint to get daily spend per team, with a breakdown of spend per API Key, Model
### Example Request
```shell
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30' \
-H 'Authorization: Bearer sk-1234'
```
### Example Response
<Tabs>
<TabItem value="response" label="Expected Response">
```shell
[
{
"group_by_day": "2024-04-30T00:00:00+00:00",
"teams": [
{
"team_name": "Prod Team",
"total_spend": 0.0015265,
"metadata": [ # see the spend by unique(key + model)
{
"model": "gpt-4",
"spend": 0.00123,
"total_tokens": 28,
"api_key": "88dc28.." # the hashed api key
},
{
"model": "gpt-4",
"spend": 0.00123,
"total_tokens": 28,
"api_key": "a73dc2.." # the hashed api key
},
{
"model": "chatgpt-v-2",
"spend": 0.000214,
"total_tokens": 122,
"api_key": "898c28.." # the hashed api key
},
{
"model": "gpt-3.5-turbo",
"spend": 0.0000825,
"total_tokens": 85,
"api_key": "84dc28.." # the hashed api key
}
]
}
]
}
]
```
</TabItem>
<TabItem value="py-script" label="Script to Parse Response (Python)">
```python
import requests
url = 'http://localhost:4000/global/spend/report'
params = {
'start_date': '2023-04-01',
'end_date': '2024-06-30'
}
headers = {
'Authorization': 'Bearer sk-1234'
}
# Make the GET request
response = requests.get(url, headers=headers, params=params)
spend_report = response.json()
for row in spend_report:
date = row["group_by_day"]
teams = row["teams"]
for team in teams:
team_name = team["team_name"]
total_spend = team["total_spend"]
metadata = team["metadata"]
print(f"Date: {date}")
print(f"Team: {team_name}")
print(f"Total Spend: {total_spend}")
print("Metadata: ", metadata)
print()
```
Output from script
```shell
# Date: 2024-05-11T00:00:00+00:00
# Team: local_test_team
# Total Spend: 0.003675099999999999
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.003675099999999999, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 3105}]
# Date: 2024-05-13T00:00:00+00:00
# Team: Unassigned Team
# Total Spend: 3.4e-05
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 3.4e-05, 'api_key': '9569d13c9777dba68096dea49b0b03e0aaf4d2b65d4030eda9e8a2733c3cd6e0', 'total_tokens': 50}]
# Date: 2024-05-13T00:00:00+00:00
# Team: central
# Total Spend: 0.000684
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.000684, 'api_key': '0323facdf3af551594017b9ef162434a9b9a8ca1bbd9ccbd9d6ce173b1015605', 'total_tokens': 498}]
# Date: 2024-05-13T00:00:00+00:00
# Team: local_test_team
# Total Spend: 0.0005715000000000001
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.0005715000000000001, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 423}]
```
</TabItem>
</Tabs>
## Reset Team, API Key Spend - MASTER KEY ONLY
Use `/global/spend/reset` if you want to:
- Reset the Spend for all API Keys, Teams. The `spend` for ALL Teams and Keys in `LiteLLM_TeamTable` and `LiteLLM_VerificationToken` will be set to `spend=0`
- LiteLLM will maintain all the logs in `LiteLLMSpendLogs` for Auditing Purposes
### Request
Only the `LITELLM_MASTER_KEY` you set can access this route
```shell
curl -X POST \
'http://localhost:4000/global/spend/reset' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json'
```
### Expected Responses
```shell
{"message":"Spend for all API Keys and Teams reset successfully","status":"success"}
```
## Spend Tracking for Azure
Set base model for cost tracking azure image-gen call Set base model for cost tracking azure image-gen call
## Image Generation ### Image Generation
```yaml ```yaml
model_list: model_list:
@ -17,7 +170,7 @@ model_list:
mode: image_generation mode: image_generation
``` ```
## Chat Completions / Embeddings ### Chat Completions / Embeddings
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking **Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking

View file

@ -3,7 +3,7 @@ import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina # 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
- [Logging to Sentry](#logging-proxy-inputoutput---sentry) - [Logging to Sentry](#logging-proxy-inputoutput---sentry)
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry) - [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
- [Logging to Athina](#logging-proxy-inputoutput-athina) - [Logging to Athina](#logging-proxy-inputoutput-athina)
- [(BETA) Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
## Custom Callback Class [Async] ## Custom Callback Class [Async]
Use this when you want to run custom callbacks in `python` Use this when you want to run custom callbacks in `python`
@ -1037,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
] ]
}' }'
``` ```
## (BETA) Moderation with Azure Content Safety
[Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content moderation APIs to detect potential offensive, harmful, or risky content in text.
We will use the `--config` to set `litellm.success_callback = ["azure_content_safety"]` this will moderate all LLM calls using Azure Content Safety.
**Step 0** Deploy Azure Content Safety
Deploy an Azure Content-Safety instance from the Azure Portal and get the `endpoint` and `key`.
**Step 1** Set Athina API key
```shell
AZURE_CONTENT_SAFETY_KEY = "<your-azure-content-safety-key>"
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
callbacks: ["azure_content_safety"]
azure_content_safety_params:
endpoint: "<your-azure-content-safety-endpoint>"
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hi, how are you?"
}
]
}'
```
An HTTP 400 error will be returned if the content is detected with a value greater than the threshold set in the `config.yaml`.
The details of the response will describe :
- The `source` : input text or llm generated text
- The `category` : the category of the content that triggered the moderation
- The `severity` : the severity from 0 to 10
**Step 4**: Customizing Azure Content Safety Thresholds
You can customize the thresholds for each category by setting the `thresholds` in the `config.yaml`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
callbacks: ["azure_content_safety"]
azure_content_safety_params:
endpoint: "<your-azure-content-safety-endpoint>"
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
thresholds:
Hate: 6
SelfHarm: 8
Sexual: 6
Violence: 4
```
:::info
`thresholds` are not required by default, but you can tune the values to your needs.
Default values is `4` for all categories
:::

View file

@ -64,6 +64,12 @@ router_settings:
redis_password: os.environ/REDIS_PASSWORD redis_password: os.environ/REDIS_PASSWORD
``` ```
## 4. Disable 'load_dotenv'
Set `export LITELLM_MODE="PRODUCTION"`
This disables the load_dotenv() functionality, which will automatically load your environment credentials from the local `.env`.
## Extras ## Extras
### Expected Performance in Production ### Expected Performance in Production

View file

@ -151,7 +151,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
}' }'
``` ```
## Advanced - Context Window Fallbacks ## Advanced - Context Window Fallbacks (Pre-Call Checks + Fallbacks)
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**. **Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
@ -287,6 +287,69 @@ print(response)
</Tabs> </Tabs>
## Advanced - EU-Region Filtering (Pre-Call Checks)
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
Set 'region_name' of deployment.
**Note:** LiteLLM can automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
**1. Set Config**
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
region_name: "eu" # 👈 SET EU-REGION
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
- model_name: gemini-pro
litellm_params:
model: vertex_ai/gemini-pro-1.5
vertex_project: adroit-crow-1234
vertex_location: us-east1 # 👈 AUTOMATICALLY INFERS 'region_name'
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.with_raw_response.create(
model="gpt-3.5-turbo",
messages = [{"role": "user", "content": "Who was Alexander?"}]
)
print(response)
print(f"response.headers.get('x-litellm-model-api-base')")
```
## Advanced - Custom Timeouts, Stream Timeouts - Per Model ## Advanced - Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params` For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml ```yaml

View file

@ -110,7 +110,7 @@ general_settings:
admin_jwt_scope: "litellm-proxy-admin" admin_jwt_scope: "litellm-proxy-admin"
``` ```
## Advanced - Spend Tracking (User / Team / Org) ## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
Set the field in the jwt token, which corresponds to a litellm user / team / org. Set the field in the jwt token, which corresponds to a litellm user / team / org.
@ -123,6 +123,7 @@ general_settings:
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
end_user_id_jwt_field: "customer_id" # 👈 CAN BE ANY FIELD
``` ```
Expected JWT: Expected JWT:
@ -131,7 +132,7 @@ Expected JWT:
{ {
"client_id": "my-unique-team", "client_id": "my-unique-team",
"sub": "my-unique-user", "sub": "my-unique-user",
"org_id": "my-unique-org" "org_id": "my-unique-org",
} }
``` ```

View file

@ -365,6 +365,188 @@ curl --location 'http://0.0.0.0:4000/moderations' \
## Advanced ## Advanced
### (BETA) Batch Completions - pass multiple models
Use this when you want to send 1 request to N Models
#### Expected Request Format
Pass model as a string of comma separated value of models. Example `"model"="llama3,gpt-3.5-turbo"`
This same request will be sent to the following model groups on the [litellm proxy config.yaml](https://docs.litellm.ai/docs/proxy/configs)
- `model_name="llama3"`
- `model_name="gpt-3.5-turbo"`
<Tabs>
<TabItem value="openai-py" label="OpenAI Python SDK">
```python
import openai
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = client.chat.completions.create(
model="gpt-3.5-turbo,llama3",
messages=[
{"role": "user", "content": "this is a test request, write a short poem"}
],
)
print(response)
```
#### Expected Response Format
Get a list of responses when `model` is passed as a list
```python
[
ChatCompletion(
id='chatcmpl-9NoYhS2G0fswot0b6QpoQgmRQMaIf',
choices=[
Choice(
finish_reason='stop',
index=0,
logprobs=None,
message=ChatCompletionMessage(
content='In the depths of my soul, a spark ignites\nA light that shines so pure and bright\nIt dances and leaps, refusing to die\nA flame of hope that reaches the sky\n\nIt warms my heart and fills me with bliss\nA reminder that in darkness, there is light to kiss\nSo I hold onto this fire, this guiding light\nAnd let it lead me through the darkest night.',
role='assistant',
function_call=None,
tool_calls=None
)
)
],
created=1715462919,
model='gpt-3.5-turbo-0125',
object='chat.completion',
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=83,
prompt_tokens=17,
total_tokens=100
)
),
ChatCompletion(
id='chatcmpl-4ac3e982-da4e-486d-bddb-ed1d5cb9c03c',
choices=[
Choice(
finish_reason='stop',
index=0,
logprobs=None,
message=ChatCompletionMessage(
content="A test request, and I'm delighted!\nHere's a short poem, just for you:\n\nMoonbeams dance upon the sea,\nA path of light, for you to see.\nThe stars up high, a twinkling show,\nA night of wonder, for all to know.\n\nThe world is quiet, save the night,\nA peaceful hush, a gentle light.\nThe world is full, of beauty rare,\nA treasure trove, beyond compare.\n\nI hope you enjoyed this little test,\nA poem born, of whimsy and jest.\nLet me know, if there's anything else!",
role='assistant',
function_call=None,
tool_calls=None
)
)
],
created=1715462919,
model='groq/llama3-8b-8192',
object='chat.completion',
system_fingerprint='fp_a2c8d063cb',
usage=CompletionUsage(
completion_tokens=120,
prompt_tokens=20,
total_tokens=140
)
)
]
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama3,gpt-3.5-turbo",
"max_tokens": 10,
"user": "litellm2",
"messages": [
{
"role": "user",
"content": "is litellm getting better"
}
]
}'
```
#### Expected Response Format
Get a list of responses when `model` is passed as a list
```json
[
{
"id": "chatcmpl-3dbd5dd8-7c82-4ca3-bf1f-7c26f497cf2b",
"choices": [
{
"finish_reason": "length",
"index": 0,
"message": {
"content": "The Elder Scrolls IV: Oblivion!\n\nReleased",
"role": "assistant"
}
}
],
"created": 1715459876,
"model": "groq/llama3-8b-8192",
"object": "chat.completion",
"system_fingerprint": "fp_179b0f92c9",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 12,
"total_tokens": 22
}
},
{
"id": "chatcmpl-9NnldUfFLmVquFHSX4yAtjCw8PGei",
"choices": [
{
"finish_reason": "length",
"index": 0,
"message": {
"content": "TES4 could refer to The Elder Scrolls IV:",
"role": "assistant"
}
}
],
"created": 1715459877,
"model": "gpt-3.5-turbo-0125",
"object": "chat.completion",
"system_fingerprint": null,
"usage": {
"completion_tokens": 10,
"prompt_tokens": 9,
"total_tokens": 19
}
}
]
```
</TabItem>
</Tabs>
### Pass User LLM API Keys, Fallbacks ### Pass User LLM API Keys, Fallbacks
Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests

View file

@ -653,7 +653,9 @@ from litellm import Router
model_list = [{...}] model_list = [{...}]
router = Router(model_list=model_list, router = Router(model_list=model_list,
allowed_fails=1) # cooldown model if it fails > 1 call in a minute. allowed_fails=1, # cooldown model if it fails > 1 call in a minute.
cooldown_time=100 # cooldown the deployment for 100 seconds if it num_fails > allowed_fails
)
user_message = "Hello, whats the weather in San Francisco??" user_message = "Hello, whats the weather in San Francisco??"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
@ -770,6 +772,8 @@ If the error is a context window exceeded error, fall back to a larger model gro
Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc. Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
You can also set 'default_fallbacks', in case a specific model group is misconfigured / bad.
```python ```python
from litellm import Router from litellm import Router
@ -830,6 +834,7 @@ model_list = [
router = Router(model_list=model_list, router = Router(model_list=model_list,
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
default_fallbacks=["gpt-3.5-turbo-16k"],
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}], context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
set_verbose=True) set_verbose=True)
@ -879,13 +884,11 @@ router = Router(model_list: Optional[list] = None,
cache_responses=True) cache_responses=True)
``` ```
## Pre-Call Checks (Context Window) ## Pre-Call Checks (Context Window, EU-Regions)
Enable pre-call checks to filter out: Enable pre-call checks to filter out:
1. deployments with context window limit < messages for a call. 1. deployments with context window limit < messages for a call.
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[ 2. deployments outside of eu-region
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
])`)
<Tabs> <Tabs>
<TabItem value="sdk" label="SDK"> <TabItem value="sdk" label="SDK">
@ -900,10 +903,14 @@ router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set t
**2. Set Model List** **2. Set Model List**
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`. For context window checks on azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
<Tabs> For 'eu-region' filtering, Set 'region_name' of deployment.
<TabItem value="same-group" label="Same Group">
**Note:** We automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
[**See Code**](https://github.com/BerriAI/litellm/blob/d33e49411d6503cb634f9652873160cd534dec96/litellm/router.py#L2958)
```python ```python
model_list = [ model_list = [
@ -914,10 +921,9 @@ model_list = [
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
}, "region_name": "eu" # 👈 SET 'EU' REGION NAME
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL "base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
} },
}, },
{ {
"model_name": "gpt-3.5-turbo", # model group name "model_name": "gpt-3.5-turbo", # model group name
@ -926,54 +932,26 @@ model_list = [
"api_key": os.getenv("OPENAI_API_KEY"), "api_key": os.getenv("OPENAI_API_KEY"),
}, },
}, },
{
"model_name": "gemini-pro",
"litellm_params: {
"model": "vertex_ai/gemini-pro-1.5",
"vertex_project": "adroit-crow-1234",
"vertex_location": "us-east1" # 👈 AUTOMATICALLY INFERS 'region_name'
}
}
] ]
router = Router(model_list=model_list, enable_pre_call_checks=True) router = Router(model_list=model_list, enable_pre_call_checks=True)
``` ```
</TabItem>
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
```python
model_list = [
{
"model_name": "gpt-3.5-turbo-small", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
}
},
{
"model_name": "gpt-3.5-turbo-large", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "claude-opus",
"litellm_params": { call
"model": "claude-3-opus-20240229",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
```
</TabItem>
</Tabs>
**3. Test it!** **3. Test it!**
<Tabs>
<TabItem value="context-window-check" label="Context Window Check">
```python ```python
""" """
- Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k) - Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
@ -983,7 +961,6 @@ router = Router(model_list=model_list, enable_pre_call_checks=True, context_wind
from litellm import Router from litellm import Router
import os import os
try:
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # model group name "model_name": "gpt-3.5-turbo", # model group name
@ -992,6 +969,7 @@ model_list = [
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
"base_model": "azure/gpt-35-turbo",
}, },
"model_info": { "model_info": {
"base_model": "azure/gpt-35-turbo", "base_model": "azure/gpt-35-turbo",
@ -1021,6 +999,59 @@ response = router.completion(
print(f"response: {response}") print(f"response: {response}")
``` ```
</TabItem> </TabItem>
<TabItem value="eu-region-check" label="EU Region Check">
```python
"""
- Give 2 gpt-3.5-turbo deployments, in eu + non-eu regions
- Make a call
- Assert it picks the eu-region model
"""
from litellm import Router
import os
model_list = [
{
"model_name": "gpt-3.5-turbo", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"region_name": "eu"
},
"model_info": {
"id": "1"
}
},
{
"model_name": "gpt-3.5-turbo", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"model_info": {
"id": "2"
}
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True)
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Who was Alexander?"}],
)
print(f"response: {response}")
print(f"response id: {response._hidden_params['model_id']}")
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="proxy" label="Proxy"> <TabItem value="proxy" label="Proxy">
:::info :::info
@ -1283,10 +1314,11 @@ def __init__(
num_retries: int = 0, num_retries: int = 0,
timeout: Optional[float] = None, timeout: Optional[float] = None,
default_litellm_params={}, # default params for Router.chat.completion.create default_litellm_params={}, # default params for Router.chat.completion.create
fallbacks: List = [], fallbacks: Optional[List] = None,
default_fallbacks: Optional[List] = None
allowed_fails: Optional[int] = None, # Number of times a deployment can failbefore being added to cooldown allowed_fails: Optional[int] = None, # Number of times a deployment can failbefore being added to cooldown
cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure
context_window_fallbacks: List = [], context_window_fallbacks: Optional[List] = None,
model_group_alias: Optional[dict] = {}, model_group_alias: Optional[dict] = {},
retry_after: int = 0, # (min) time to wait before retrying a failed request retry_after: int = 0, # (min) time to wait before retrying a failed request
routing_strategy: Literal[ routing_strategy: Literal[

Binary file not shown.

After

Width:  |  Height:  |  Size: 344 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

View file

@ -39,7 +39,9 @@ const sidebars = {
"proxy/demo", "proxy/demo",
"proxy/configs", "proxy/configs",
"proxy/reliability", "proxy/reliability",
"proxy/cost_tracking",
"proxy/users", "proxy/users",
"proxy/billing",
"proxy/user_keys", "proxy/user_keys",
"proxy/enterprise", "proxy/enterprise",
"proxy/virtual_keys", "proxy/virtual_keys",
@ -52,7 +54,6 @@ const sidebars = {
"proxy/team_based_routing", "proxy/team_based_routing",
"proxy/customer_routing", "proxy/customer_routing",
"proxy/ui", "proxy/ui",
"proxy/cost_tracking",
"proxy/token_auth", "proxy/token_auth",
{ {
type: "category", type: "category",
@ -134,6 +135,7 @@ const sidebars = {
"providers/huggingface", "providers/huggingface",
"providers/watsonx", "providers/watsonx",
"providers/predibase", "providers/predibase",
"providers/triton-inference-server",
"providers/ollama", "providers/ollama",
"providers/perplexity", "providers/perplexity",
"providers/groq", "providers/groq",
@ -174,6 +176,7 @@ const sidebars = {
"observability/custom_callback", "observability/custom_callback",
"observability/langfuse_integration", "observability/langfuse_integration",
"observability/sentry", "observability/sentry",
"observability/lago",
"observability/openmeter", "observability/openmeter",
"observability/promptlayer_integration", "observability/promptlayer_integration",
"observability/wandb_integration", "observability/wandb_integration",
@ -188,7 +191,7 @@ const sidebars = {
`observability/telemetry`, `observability/telemetry`,
], ],
}, },
"caching/redis_cache", "caching/all_caches",
{ {
type: "category", type: "category",
label: "Tutorials", label: "Tutorials",

View file

@ -10,7 +10,6 @@ from litellm.caching import DualCache
from typing import Literal, Union from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -19,8 +18,6 @@ import traceback
import dotenv, os import dotenv, os
import requests import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -1,6 +1,7 @@
# Enterprise Proxy Util Endpoints # Enterprise Proxy Util Endpoints
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
import collections import collections
from datetime import datetime
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None): async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
@ -18,26 +19,33 @@ async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
return response return response
async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=None): async def ui_get_spend_by_tags(start_date: str, end_date: str, prisma_client):
response = await prisma_client.db.query_raw(
""" sql_query = """
SELECT SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag, jsonb_array_elements_text(request_tags) AS individual_request_tag,
DATE(s."startTime") AS spend_date, DATE(s."startTime") AS spend_date,
COUNT(*) AS log_count, COUNT(*) AS log_count,
SUM(spend) AS total_spend SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs" s FROM "LiteLLM_SpendLogs" s
WHERE s."startTime" >= current_date - interval '30 days' WHERE
DATE(s."startTime") >= $1::date
AND DATE(s."startTime") <= $2::date
GROUP BY individual_request_tag, spend_date GROUP BY individual_request_tag, spend_date
ORDER BY spend_date; ORDER BY spend_date
LIMIT 100;
""" """
response = await prisma_client.db.query_raw(
sql_query,
start_date,
end_date,
) )
# print("tags - spend") # print("tags - spend")
# print(response) # print(response)
# Bar Chart 1 - Spend per tag - Top 10 tags by spend # Bar Chart 1 - Spend per tag - Top 10 tags by spend
total_spend_per_tag = collections.defaultdict(float) total_spend_per_tag: collections.defaultdict = collections.defaultdict(float)
total_requests_per_tag = collections.defaultdict(int) total_requests_per_tag: collections.defaultdict = collections.defaultdict(int)
for row in response: for row in response:
tag_name = row["individual_request_tag"] tag_name = row["individual_request_tag"]
tag_spend = row["total_spend"] tag_spend = row["total_spend"]
@ -49,15 +57,18 @@ async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=Non
# convert to ui format # convert to ui format
ui_tags = [] ui_tags = []
for tag in sorted_tags: for tag in sorted_tags:
current_spend = tag[1]
if current_spend is not None and isinstance(current_spend, float):
current_spend = round(current_spend, 4)
ui_tags.append( ui_tags.append(
{ {
"name": tag[0], "name": tag[0],
"value": tag[1], "spend": current_spend,
"log_count": total_requests_per_tag[tag[0]], "log_count": total_requests_per_tag[tag[0]],
} }
) )
return {"top_10_tags": ui_tags} return {"spend_per_tag": ui_tags}
async def view_spend_logs_from_clickhouse( async def view_spend_logs_from_clickhouse(

View file

@ -1,5 +1,6 @@
### Hide pydantic namespace conflict warnings globally ### ### Hide pydantic namespace conflict warnings globally ###
import warnings import warnings
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*") warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ### ### INIT VARIABLES ###
import threading, requests, os import threading, requests, os
@ -14,7 +15,9 @@ from litellm.proxy._types import (
import httpx import httpx
import dotenv import dotenv
dotenv.load_dotenv() litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
if litellm_mode == "DEV":
dotenv.load_dotenv()
############################################# #############################################
if set_verbose == True: if set_verbose == True:
_turn_on_debug() _turn_on_debug()
@ -24,8 +27,8 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = [] service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = [] _custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter"]
_custom_logger_compatible_callbacks: list = ["openmeter"] callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
_langfuse_default_tags: Optional[ _langfuse_default_tags: Optional[
List[ List[
Literal[ Literal[
@ -70,6 +73,7 @@ azure_key: Optional[str] = None
anthropic_key: Optional[str] = None anthropic_key: Optional[str] = None
replicate_key: Optional[str] = None replicate_key: Optional[str] = None
cohere_key: Optional[str] = None cohere_key: Optional[str] = None
clarifai_key: Optional[str] = None
maritalk_key: Optional[str] = None maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None ai21_key: Optional[str] = None
ollama_key: Optional[str] = None ollama_key: Optional[str] = None
@ -100,6 +104,9 @@ blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
################## ##################
### PREVIEW FEATURES ###
enable_preview_features: bool = False
##################
logging: bool = True logging: bool = True
caching: bool = ( caching: bool = (
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
@ -214,6 +221,7 @@ max_end_user_budget: Optional[float] = None
#### RELIABILITY #### #### RELIABILITY ####
request_timeout: Optional[float] = 6000 request_timeout: Optional[float] = 6000
num_retries: Optional[int] = None # per model endpoint num_retries: Optional[int] = None # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None
allowed_fails: int = 0 allowed_fails: int = 0
@ -400,6 +408,73 @@ replicate_models: List = [
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad", "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
] ]
clarifai_models: List = [
"clarifai/meta.Llama-3.Llama-3-8B-Instruct",
"clarifai/gcp.generate.gemma-1_1-7b-it",
"clarifai/mistralai.completion.mixtral-8x22B",
"clarifai/cohere.generate.command-r-plus",
"clarifai/databricks.drbx.dbrx-instruct",
"clarifai/mistralai.completion.mistral-large",
"clarifai/mistralai.completion.mistral-medium",
"clarifai/mistralai.completion.mistral-small",
"clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
"clarifai/gcp.generate.gemma-2b-it",
"clarifai/gcp.generate.gemma-7b-it",
"clarifai/deci.decilm.deciLM-7B-instruct",
"clarifai/mistralai.completion.mistral-7B-Instruct",
"clarifai/gcp.generate.gemini-pro",
"clarifai/anthropic.completion.claude-v1",
"clarifai/anthropic.completion.claude-instant-1_2",
"clarifai/anthropic.completion.claude-instant",
"clarifai/anthropic.completion.claude-v2",
"clarifai/anthropic.completion.claude-2_1",
"clarifai/meta.Llama-2.codeLlama-70b-Python",
"clarifai/meta.Llama-2.codeLlama-70b-Instruct",
"clarifai/openai.completion.gpt-3_5-turbo-instruct",
"clarifai/meta.Llama-2.llama2-7b-chat",
"clarifai/meta.Llama-2.llama2-13b-chat",
"clarifai/meta.Llama-2.llama2-70b-chat",
"clarifai/openai.chat-completion.gpt-4-turbo",
"clarifai/microsoft.text-generation.phi-2",
"clarifai/meta.Llama-2.llama2-7b-chat-vllm",
"clarifai/upstage.solar.solar-10_7b-instruct",
"clarifai/openchat.openchat.openchat-3_5-1210",
"clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
"clarifai/gcp.generate.text-bison",
"clarifai/meta.Llama-2.llamaGuard-7b",
"clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
"clarifai/openai.chat-completion.GPT-4",
"clarifai/openai.chat-completion.GPT-3_5-turbo",
"clarifai/ai21.complete.Jurassic2-Grande",
"clarifai/ai21.complete.Jurassic2-Grande-Instruct",
"clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
"clarifai/ai21.complete.Jurassic2-Jumbo",
"clarifai/ai21.complete.Jurassic2-Large",
"clarifai/cohere.generate.cohere-generate-command",
"clarifai/wizardlm.generate.wizardCoder-Python-34B",
"clarifai/wizardlm.generate.wizardLM-70B",
"clarifai/tiiuae.falcon.falcon-40b-instruct",
"clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
"clarifai/gcp.generate.code-gecko",
"clarifai/gcp.generate.code-bison",
"clarifai/mistralai.completion.mistral-7B-OpenOrca",
"clarifai/mistralai.completion.openHermes-2-mistral-7B",
"clarifai/wizardlm.generate.wizardLM-13B",
"clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
"clarifai/wizardlm.generate.wizardCoder-15B",
"clarifai/microsoft.text-generation.phi-1_5",
"clarifai/databricks.Dolly-v2.dolly-v2-12b",
"clarifai/bigcode.code.StarCoder",
"clarifai/salesforce.xgen.xgen-7b-8k-instruct",
"clarifai/mosaicml.mpt.mpt-7b-instruct",
"clarifai/anthropic.completion.claude-3-opus",
"clarifai/anthropic.completion.claude-3-sonnet",
"clarifai/gcp.generate.gemini-1_5-pro",
"clarifai/gcp.generate.imagen-2",
"clarifai/salesforce.blip.general-english-image-caption-blip-2",
]
huggingface_models: List = [ huggingface_models: List = [
"meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-7b-hf",
"meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-7b-chat-hf",
@ -505,6 +580,7 @@ provider_list: List = [
"text-completion-openai", "text-completion-openai",
"cohere", "cohere",
"cohere_chat", "cohere_chat",
"clarifai",
"anthropic", "anthropic",
"replicate", "replicate",
"huggingface", "huggingface",
@ -537,6 +613,7 @@ provider_list: List = [
"xinference", "xinference",
"fireworks_ai", "fireworks_ai",
"watsonx", "watsonx",
"triton",
"predibase", "predibase",
"custom", # custom apis "custom", # custom apis
] ]
@ -654,6 +731,7 @@ from .llms.predibase import PredibaseConfig
from .llms.anthropic_text import AnthropicTextConfig from .llms.anthropic_text import AnthropicTextConfig
from .llms.replicate import ReplicateConfig from .llms.replicate import ReplicateConfig
from .llms.cohere import CohereConfig from .llms.cohere import CohereConfig
from .llms.clarifai import ClarifaiConfig
from .llms.ai21 import AI21Config from .llms.ai21 import AI21Config
from .llms.together_ai import TogetherAIConfig from .llms.together_ai import TogetherAIConfig
from .llms.cloudflare import CloudflareConfig from .llms.cloudflare import CloudflareConfig
@ -668,6 +746,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig
from .llms.bedrock import ( from .llms.bedrock import (
AmazonTitanConfig, AmazonTitanConfig,
AmazonAI21Config, AmazonAI21Config,
@ -679,7 +758,7 @@ from .llms.bedrock import (
AmazonMistralConfig, AmazonMistralConfig,
AmazonBedrockGlobalConfig, AmazonBedrockGlobalConfig,
) )
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, MistralConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
from .llms.watsonx import IBMWatsonXAIConfig from .llms.watsonx import IBMWatsonXAIConfig
from .main import * # type: ignore from .main import * # type: ignore

View file

@ -373,11 +373,12 @@ class RedisCache(BaseCache):
print_verbose( print_verbose(
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
) )
json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided. # Set the value with a TTL if it's provided.
if ttl is not None: if ttl is not None:
pipe.setex(cache_key, ttl, json.dumps(cache_value)) pipe.setex(cache_key, ttl, json_cache_value)
else: else:
pipe.set(cache_key, json.dumps(cache_value)) pipe.set(cache_key, json_cache_value)
# Execute the pipeline and return the results. # Execute the pipeline and return the results.
results = await pipe.execute() results = await pipe.execute()
@ -810,9 +811,7 @@ class RedisSemanticCache(BaseCache):
# get the prompt # get the prompt
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
# create an embedding for prompt # create an embedding for prompt
embedding_response = litellm.embedding( embedding_response = litellm.embedding(
@ -847,9 +846,7 @@ class RedisSemanticCache(BaseCache):
# get the messages # get the messages
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
# convert to embedding # convert to embedding
embedding_response = litellm.embedding( embedding_response = litellm.embedding(
@ -909,9 +906,7 @@ class RedisSemanticCache(BaseCache):
# get the prompt # get the prompt
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
# create an embedding for prompt # create an embedding for prompt
router_model_names = ( router_model_names = (
[m["model_name"] for m in llm_model_list] [m["model_name"] for m in llm_model_list]
@ -964,9 +959,7 @@ class RedisSemanticCache(BaseCache):
# get the messages # get the messages
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "" prompt = "".join(message["content"] for message in messages)
for message in messages:
prompt += message["content"]
router_model_names = ( router_model_names = (
[m["model_name"] for m in llm_model_list] [m["model_name"] for m in llm_model_list]
@ -1448,7 +1441,7 @@ class DualCache(BaseCache):
class Cache: class Cache:
def __init__( def __init__(
self, self,
type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local", type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
@ -1491,13 +1484,14 @@ class Cache:
redis_semantic_cache_use_async=False, redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002", redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None, redis_flush_size=None,
disk_cache_dir=None,
**kwargs, **kwargs,
): ):
""" """
Initializes the cache based on the given type. Initializes the cache based on the given type.
Args: Args:
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local". type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
host (str, optional): The host address for the Redis cache. Required if type is "redis". host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis".
@ -1543,6 +1537,8 @@ class Cache:
s3_path=s3_path, s3_path=s3_path,
**kwargs, **kwargs,
) )
elif type == "disk":
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
if "cache" not in litellm.input_callback: if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache") litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback: if "cache" not in litellm.success_callback:
@ -1914,8 +1910,86 @@ class Cache:
await self.cache.disconnect() await self.cache.disconnect()
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
import diskcache as dc
# if users don't provider one, use the default litellm cache
if disk_cache_dir is None:
self.disk_cache = dc.Cache(".litellm_cache")
else:
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
print_verbose("DiskCache: set_cache")
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
self.disk_cache.set(key, value)
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None):
for cache_key, cache_value in cache_list:
if ttl is not None:
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
else:
self.set_cache(key=cache_key, value=cache_value)
def get_cache(self, key, **kwargs):
original_cached_response = self.disk_cache.get(key)
if original_cached_response:
try:
cached_response = json.loads(original_cached_response)
except:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.disk_cache.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.disk_cache.pop(key)
def enable_cache( def enable_cache(
type: Optional[Literal["local", "redis", "s3"]] = "local", type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
@ -1944,7 +2018,7 @@ def enable_cache(
Enable cache with the specified configuration. Enable cache with the specified configuration.
Args: Args:
type (Optional[Literal["local", "redis"]]): The type of cache to enable. Defaults to "local". type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
host (Optional[str]): The host address of the cache server. Defaults to None. host (Optional[str]): The host address of the cache server. Defaults to None.
port (Optional[str]): The port number of the cache server. Defaults to None. port (Optional[str]): The port number of the cache server. Defaults to None.
password (Optional[str]): The password for the cache server. Defaults to None. password (Optional[str]): The password for the cache server. Defaults to None.
@ -1980,7 +2054,7 @@ def enable_cache(
def update_cache( def update_cache(
type: Optional[Literal["local", "redis"]] = "local", type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
@ -2009,7 +2083,7 @@ def update_cache(
Update the cache for LiteLLM. Update the cache for LiteLLM.
Args: Args:
type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local". type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
host (Optional[str]): The host of the cache. Defaults to None. host (Optional[str]): The host of the cache. Defaults to None.
port (Optional[str]): The port of the cache. Defaults to None. port (Optional[str]): The port of the cache. Defaults to None.
password (Optional[str]): The password for the cache. Defaults to None. password (Optional[str]): The password for the cache. Defaults to None.

View file

@ -9,55 +9,64 @@
## LiteLLM versions of the OpenAI Exception Types ## LiteLLM versions of the OpenAI Exception Types
from openai import ( import openai
AuthenticationError,
BadRequestError,
NotFoundError,
RateLimitError,
APIStatusError,
OpenAIError,
APIError,
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError,
PermissionDeniedError,
)
import httpx import httpx
from typing import Optional from typing import Optional
class AuthenticationError(AuthenticationError): # type: ignore class AuthenticationError(openai.AuthenticationError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 401 self.status_code = 401
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# raise when invalid models passed, example gpt-8 # raise when invalid models passed, example gpt-8
class NotFoundError(NotFoundError): # type: ignore class NotFoundError(openai.NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 404 self.status_code = 404
self.message = message self.message = message
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class BadRequestError(BadRequestError): # type: ignore class BadRequestError(openai.BadRequestError): # type: ignore
def __init__( def __init__(
self, message, model, llm_provider, response: Optional[httpx.Response] = None self,
message,
model,
llm_provider,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
): ):
self.status_code = 400 self.status_code = 400
self.message = message self.message = message
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
response = response or httpx.Response( response = response or httpx.Response(
status_code=self.status_code, status_code=self.status_code,
request=httpx.Request( request=httpx.Request(
@ -69,19 +78,29 @@ class BadRequestError(BadRequestError): # type: ignore
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 422 self.status_code = 422
self.message = message self.message = message
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore class Timeout(openai.APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(
self, message, model, llm_provider, litellm_debug_info: Optional[str] = None
):
request = httpx.Request(method="POST", url="https://api.openai.com/v1") request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__( super().__init__(
request=request request=request
@ -90,29 +109,46 @@ class Timeout(APITimeoutError): # type: ignore
self.message = message self.message = message
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
# custom function to convert to str # custom function to convert to str
def __str__(self): def __str__(self):
return str(self.message) return str(self.message)
class PermissionDeniedError(PermissionDeniedError): # type:ignore class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 403 self.status_code = 403
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class RateLimitError(RateLimitError): # type: ignore class RateLimitError(openai.RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 429 self.status_code = 429
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.modle = model self.modle = model
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
@ -120,11 +156,19 @@ class RateLimitError(RateLimitError): # type: ignore
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors # sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400 self.status_code = 400
self.message = message self.message = message
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
message=self.message, message=self.message,
model=self.model, # type: ignore model=self.model, # type: ignore
@ -135,11 +179,19 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
class ContentPolicyViolationError(BadRequestError): # type: ignore class ContentPolicyViolationError(BadRequestError): # type: ignore
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}} # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400 self.status_code = 400
self.message = message self.message = message
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
message=self.message, message=self.message,
model=self.model, # type: ignore model=self.model, # type: ignore
@ -148,51 +200,77 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(APIStatusError): # type: ignore class ServiceUnavailableError(openai.APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 503 self.status_code = 503
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401 # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(APIError): # type: ignore class APIError(openai.APIError): # type: ignore
def __init__( def __init__(
self, status_code, message, llm_provider, model, request: httpx.Request self,
status_code,
message,
llm_provider,
model,
request: httpx.Request,
litellm_debug_info: Optional[str] = None,
): ):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__(self.message, request=request, body=None) # type: ignore super().__init__(self.message, request=request, body=None) # type: ignore
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(APIConnectionError): # type: ignore class APIConnectionError(openai.APIConnectionError): # type: ignore
def __init__(self, message, llm_provider, model, request: httpx.Request): def __init__(
self,
message,
llm_provider,
model,
request: httpx.Request,
litellm_debug_info: Optional[str] = None,
):
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.status_code = 500 self.status_code = 500
self.litellm_debug_info = litellm_debug_info
super().__init__(message=self.message, request=request) super().__init__(message=self.message, request=request)
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(APIResponseValidationError): # type: ignore class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
def __init__(self, message, llm_provider, model): def __init__(
self, message, llm_provider, model, litellm_debug_info: Optional[str] = None
):
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1") request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request) response = httpx.Response(status_code=500, request=request)
self.litellm_debug_info = litellm_debug_info
super().__init__(response=response, body=None, message=message) super().__init__(response=response, body=None, message=message)
class OpenAIError(OpenAIError): # type: ignore class OpenAIError(openai.OpenAIError): # type: ignore
def __init__(self, original_exception): def __init__(self, original_exception):
self.status_code = original_exception.http_status self.status_code = original_exception.http_status
super().__init__( super().__init__(
@ -214,7 +292,7 @@ class BudgetExceededError(Exception):
## DEPRECATED ## ## DEPRECATED ##
class InvalidRequestError(BadRequestError): # type: ignore class InvalidRequestError(openai.BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 400 self.status_code = 400
self.message = message self.message = message

View file

@ -1,8 +1,6 @@
#### What this does #### #### What this does ####
# On success + failure, log events to aispend.io # On success + failure, log events to aispend.io
import dotenv, os import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime import datetime

View file

@ -3,7 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime import datetime

View file

@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal, Union from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -18,8 +16,6 @@ import traceback
import dotenv, os import dotenv, os
import requests import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal, Union, Optional from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
import litellm import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -0,0 +1,179 @@
# What is this?
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
import dotenv, os, json
import litellm
import traceback, httpx
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
import uuid
from typing import Optional, Literal
def get_utc_datetime():
import datetime as dt
from datetime import datetime
if hasattr(dt, "UTC"):
return datetime.now(dt.UTC) # type: ignore
else:
return datetime.utcnow() # type: ignore
class LagoLogger(CustomLogger):
def __init__(self) -> None:
super().__init__()
self.validate_environment()
self.async_http_handler = AsyncHTTPHandler()
self.sync_http_handler = HTTPHandler()
def validate_environment(self):
"""
Expects
LAGO_API_BASE,
LAGO_API_KEY,
LAGO_API_EVENT_CODE,
Optional:
LAGO_API_CHARGE_BY
in the environment
"""
missing_keys = []
if os.getenv("LAGO_API_KEY", None) is None:
missing_keys.append("LAGO_API_KEY")
if os.getenv("LAGO_API_BASE", None) is None:
missing_keys.append("LAGO_API_BASE")
if os.getenv("LAGO_API_EVENT_CODE", None) is None:
missing_keys.append("LAGO_API_EVENT_CODE")
if len(missing_keys) > 0:
raise Exception("Missing keys={} in environment.".format(missing_keys))
def _common_logic(self, kwargs: dict, response_obj) -> dict:
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
dt = get_utc_datetime().isoformat()
cost = kwargs.get("response_cost", None)
model = kwargs.get("model")
usage = {}
if (
isinstance(response_obj, litellm.ModelResponse)
or isinstance(response_obj, litellm.EmbeddingResponse)
) and hasattr(response_obj, "usage"):
usage = {
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
"total_tokens": response_obj["usage"].get("total_tokens"),
}
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
external_customer_id: Optional[str] = None
if os.getenv("LAGO_API_CHARGE_BY", None) is not None and isinstance(
os.environ["LAGO_API_CHARGE_BY"], str
):
if os.environ["LAGO_API_CHARGE_BY"] in [
"end_user_id",
"user_id",
"team_id",
]:
charge_by = os.environ["LAGO_API_CHARGE_BY"] # type: ignore
else:
raise Exception("invalid LAGO_API_CHARGE_BY set")
if charge_by == "end_user_id":
external_customer_id = end_user_id
elif charge_by == "team_id":
external_customer_id = team_id
elif charge_by == "user_id":
external_customer_id = user_id
if external_customer_id is None:
raise Exception("External Customer ID is not set")
return {
"event": {
"transaction_id": str(uuid.uuid4()),
"external_customer_id": external_customer_id,
"code": os.getenv("LAGO_API_EVENT_CODE"),
"properties": {"model": model, "response_cost": cost, **usage},
}
}
def log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance(
_url, str
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("LAGO_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
try:
response = self.sync_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
_url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance(
_url, str
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
_url
)
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("LAGO_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
except Exception as e:
raise e
response: Optional[httpx.Response] = None
try:
response = await self.async_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if response is not None and hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e

View file

@ -1,8 +1,6 @@
#### What this does #### #### What this does ####
# On success, logs events to Langfuse # On success, logs events to Langfuse
import dotenv, os import os
dotenv.load_dotenv() # Loading env variables using dotenv
import copy import copy
import traceback import traceback
from packaging.version import Version from packaging.version import Version
@ -262,7 +260,23 @@ class LangFuseLogger:
try: try:
tags = [] tags = []
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata try:
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except:
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, dict)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = copy.deepcopy(value)
metadata = new_metadata
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3") supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3") supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3") supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
@ -307,6 +321,9 @@ class LangFuseLogger:
trace_id = clean_metadata.pop("trace_id", None) trace_id = clean_metadata.pop("trace_id", None)
existing_trace_id = clean_metadata.pop("existing_trace_id", None) existing_trace_id = clean_metadata.pop("existing_trace_id", None)
update_trace_keys = clean_metadata.pop("update_trace_keys", []) update_trace_keys = clean_metadata.pop("update_trace_keys", [])
debug = clean_metadata.pop("debug_langfuse", None)
mask_input = clean_metadata.pop("mask_input", False)
mask_output = clean_metadata.pop("mask_output", False)
if trace_name is None and existing_trace_id is None: if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name # just log `litellm-{call_type}` as the trace name
@ -334,18 +351,19 @@ class LangFuseLogger:
# Special keys that are found in the function arguments and not the metadata # Special keys that are found in the function arguments and not the metadata
if "input" in update_trace_keys: if "input" in update_trace_keys:
trace_params["input"] = input trace_params["input"] = input if not mask_input else "redacted-by-litellm"
if "output" in update_trace_keys: if "output" in update_trace_keys:
trace_params["output"] = output trace_params["output"] = output if not mask_output else "redacted-by-litellm"
else: # don't overwrite an existing trace else: # don't overwrite an existing trace
trace_params = { trace_params = {
"id": trace_id, "id": trace_id,
"name": trace_name, "name": trace_name,
"session_id": session_id, "session_id": session_id,
"input": input, "input": input if not mask_input else "redacted-by-litellm",
"version": clean_metadata.pop( "version": clean_metadata.pop(
"trace_version", clean_metadata.get("version", None) "trace_version", clean_metadata.get("version", None)
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence ), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
"user_id": user_id,
} }
for key in list( for key in list(
filter(lambda key: key.startswith("trace_"), clean_metadata.keys()) filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
@ -357,7 +375,14 @@ class LangFuseLogger:
if level == "ERROR": if level == "ERROR":
trace_params["status_message"] = output trace_params["status_message"] = output
else: else:
trace_params["output"] = output trace_params["output"] = output if not mask_output else "redacted-by-litellm"
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
if "metadata" in trace_params:
# log the raw_metadata in the trace
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
else:
trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
cost = kwargs.get("response_cost", None) cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}") print_verbose(f"trace: {cost}")
@ -409,7 +434,6 @@ class LangFuseLogger:
"url": url, "url": url,
"headers": clean_headers, "headers": clean_headers,
} }
trace = self.Langfuse.trace(**trace_params) trace = self.Langfuse.trace(**trace_params)
generation_id = None generation_id = None
@ -441,8 +465,8 @@ class LangFuseLogger:
"end_time": end_time, "end_time": end_time,
"model": kwargs["model"], "model": kwargs["model"],
"model_parameters": optional_params, "model_parameters": optional_params,
"input": input, "input": input if not mask_input else "redacted-by-litellm",
"output": output, "output": output if not mask_output else "redacted-by-litellm",
"usage": usage, "usage": usage,
"metadata": clean_metadata, "metadata": clean_metadata,
"level": level, "level": level,
@ -450,7 +474,29 @@ class LangFuseLogger:
} }
if supports_prompt: if supports_prompt:
generation_params["prompt"] = clean_metadata.pop("prompt", None) user_prompt = clean_metadata.pop("prompt", None)
if user_prompt is None:
pass
elif isinstance(user_prompt, dict):
from langfuse.model import (
TextPromptClient,
ChatPromptClient,
Prompt_Text,
Prompt_Chat,
)
if user_prompt.get("type", "") == "chat":
_prompt_chat = Prompt_Chat(**user_prompt)
generation_params["prompt"] = ChatPromptClient(
prompt=_prompt_chat
)
elif user_prompt.get("type", "") == "text":
_prompt_text = Prompt_Text(**user_prompt)
generation_params["prompt"] = TextPromptClient(
prompt=_prompt_text
)
else:
generation_params["prompt"] = user_prompt
if output is not None and isinstance(output, str) and level == "ERROR": if output is not None and isinstance(output, str) and level == "ERROR":
generation_params["status_message"] = output generation_params["status_message"] = output

View file

@ -3,8 +3,6 @@
import dotenv, os # type: ignore import dotenv, os # type: ignore
import requests # type: ignore import requests # type: ignore
from datetime import datetime from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import asyncio import asyncio
import types import types

View file

@ -2,14 +2,10 @@
# On success + failure, log events to lunary.ai # On success + failure, log events to lunary.ai
from datetime import datetime, timezone from datetime import datetime, timezone
import traceback import traceback
import dotenv
import importlib import importlib
import sys
import packaging import packaging
dotenv.load_dotenv()
# convert to {completion: xx, tokens: xx} # convert to {completion: xx, tokens: xx}
def parse_usage(usage): def parse_usage(usage):
@ -18,13 +14,33 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
} }
def parse_tool_calls(tool_calls):
if tool_calls is None:
return None
def clean_tool_call(tool_call):
serialized = {
"type": tool_call.type,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
}
return serialized
return [clean_tool_call(tool_call) for tool_call in tool_calls]
def parse_messages(input): def parse_messages(input):
if input is None: if input is None:
return None return None
def clean_message(message): def clean_message(message):
# if is strin, return as is # if is string, return as is
if isinstance(message, str): if isinstance(message, str):
return message return message
@ -38,9 +54,7 @@ def parse_messages(input):
# Only add tool_calls and function_call to res if they are set # Only add tool_calls and function_call to res if they are set
if message.get("tool_calls"): if message.get("tool_calls"):
serialized["tool_calls"] = message.get("tool_calls") serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
if message.get("function_call"):
serialized["function_call"] = message.get("function_call")
return serialized return serialized
@ -62,14 +76,16 @@ class LunaryLogger:
version = importlib.metadata.version("lunary") version = importlib.metadata.version("lunary")
# if version < 0.1.43 then raise ImportError # if version < 0.1.43 then raise ImportError
if packaging.version.Version(version) < packaging.version.Version("0.1.43"): if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
print( print( # noqa
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'" "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
) )
raise ImportError raise ImportError
self.lunary_client = lunary self.lunary_client = lunary
except ImportError: except ImportError:
print("Lunary not installed. Please install it using 'pip install lunary'") print( # noqa
"Lunary not installed. Please install it using 'pip install lunary'"
) # noqa
raise ImportError raise ImportError
def log_event( def log_event(
@ -93,8 +109,13 @@ class LunaryLogger:
print_verbose(f"Lunary Logging - Logging request for model {model}") print_verbose(f"Lunary Logging - Logging request for model {model}")
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
optional_params = kwargs.get("optional_params", {})
metadata = litellm_params.get("metadata", {}) or {} metadata = litellm_params.get("metadata", {}) or {}
if optional_params:
# merge into extra
extra = {**extra, **optional_params}
tags = litellm_params.pop("tags", None) or [] tags = litellm_params.pop("tags", None) or []
if extra: if extra:
@ -104,7 +125,7 @@ class LunaryLogger:
# keep only serializable types # keep only serializable types
for param, value in extra.items(): for param, value in extra.items():
if not isinstance(value, (str, int, bool, float)): if not isinstance(value, (str, int, bool, float)) and param != "tools":
try: try:
extra[param] = str(value) extra[param] = str(value)
except: except:
@ -140,7 +161,7 @@ class LunaryLogger:
metadata=metadata, metadata=metadata,
runtime="litellm", runtime="litellm",
tags=tags, tags=tags,
extra=extra, params=extra,
) )
self.lunary_client.track_event( self.lunary_client.track_event(

View file

@ -3,8 +3,6 @@
import dotenv, os, json import dotenv, os, json
import litellm import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

View file

@ -4,8 +4,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -5,8 +5,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -1,9 +1,7 @@
#### What this does #### #### What this does ####
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os import os
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid

View file

@ -2,8 +2,6 @@
# Class for sending Slack Alerts # # Class for sending Slack Alerts #
import dotenv, os import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
dotenv.load_dotenv() # Loading env variables using dotenv
from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm, threading import litellm, threading
from typing import List, Literal, Any, Union, Optional, Dict from typing import List, Literal, Any, Union, Optional, Dict
@ -14,7 +12,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import datetime import datetime
from pydantic import BaseModel from pydantic import BaseModel
from enum import Enum from enum import Enum
from datetime import datetime as dt, timedelta from datetime import datetime as dt, timedelta, timezone
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
import random import random
@ -33,7 +31,10 @@ class LiteLLMBase(BaseModel):
class SlackAlertingArgs(LiteLLMBase): class SlackAlertingArgs(LiteLLMBase):
daily_report_frequency: int = 12 * 60 * 60 # 12 hours default_daily_report_frequency: int = 12 * 60 * 60 # 12 hours
daily_report_frequency: int = int(
os.getenv("SLACK_DAILY_REPORT_FREQUENCY", default_daily_report_frequency)
)
report_check_interval: int = 5 * 60 # 5 minutes report_check_interval: int = 5 * 60 # 5 minutes
@ -78,8 +79,7 @@ class SlackAlerting(CustomLogger):
internal_usage_cache: Optional[DualCache] = None, internal_usage_cache: Optional[DualCache] = None,
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds) alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [], alerting: Optional[List] = [],
alert_types: Optional[ alert_types: List[
List[
Literal[ Literal[
"llm_exceptions", "llm_exceptions",
"llm_too_slow", "llm_too_slow",
@ -88,7 +88,6 @@ class SlackAlerting(CustomLogger):
"db_exceptions", "db_exceptions",
"daily_reports", "daily_reports",
] ]
]
] = [ ] = [
"llm_exceptions", "llm_exceptions",
"llm_too_slow", "llm_too_slow",
@ -242,6 +241,8 @@ class SlackAlerting(CustomLogger):
end_time=end_time, end_time=end_time,
) )
) )
if litellm.turn_off_message_logging:
messages = "Message not logged. `litellm.turn_off_message_logging=True`."
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold: if time_difference_float > self.alerting_threshold:
@ -348,8 +349,9 @@ class SlackAlerting(CustomLogger):
all_none = True all_none = True
for val in combined_metrics_values: for val in combined_metrics_values:
if val is not None: if val is not None and val > 0:
all_none = False all_none = False
break
if all_none: if all_none:
return False return False
@ -367,12 +369,15 @@ class SlackAlerting(CustomLogger):
for value in failed_request_values for value in failed_request_values
] ]
## Get the indices of top 5 keys with the highest numerical values (ignoring None values) ## Get the indices of top 5 keys with the highest numerical values (ignoring None and 0 values)
top_5_failed = sorted( top_5_failed = sorted(
range(len(replaced_failed_values)), range(len(replaced_failed_values)),
key=lambda i: replaced_failed_values[i], key=lambda i: replaced_failed_values[i],
reverse=True, reverse=True,
)[:5] )[:5]
top_5_failed = [
index for index in top_5_failed if replaced_failed_values[index] > 0
]
# find top 5 slowest # find top 5 slowest
# Replace None values with a placeholder value (-1 in this case) # Replace None values with a placeholder value (-1 in this case)
@ -382,17 +387,22 @@ class SlackAlerting(CustomLogger):
for value in latency_values for value in latency_values
] ]
# Get the indices of top 5 values with the highest numerical values (ignoring None values) # Get the indices of top 5 values with the highest numerical values (ignoring None and 0 values)
top_5_slowest = sorted( top_5_slowest = sorted(
range(len(replaced_slowest_values)), range(len(replaced_slowest_values)),
key=lambda i: replaced_slowest_values[i], key=lambda i: replaced_slowest_values[i],
reverse=True, reverse=True,
)[:5] )[:5]
top_5_slowest = [
index for index in top_5_slowest if replaced_slowest_values[index] > 0
]
# format alert -> return the litellm model name + api base # format alert -> return the litellm model name + api base
message = f"\n\nHere are today's key metrics 📈: \n\n" message = f"\n\nHere are today's key metrics 📈: \n\n"
message += "\n\n*❗️ Top 5 Deployments with Most Failed Requests:*\n\n" message += "\n\n*❗️ Top Deployments with Most Failed Requests:*\n\n"
if not top_5_failed:
message += "\tNone\n"
for i in range(len(top_5_failed)): for i in range(len(top_5_failed)):
key = failed_request_keys[top_5_failed[i]].split(":")[0] key = failed_request_keys[top_5_failed[i]].split(":")[0]
_deployment = router.get_model_info(key) _deployment = router.get_model_info(key)
@ -412,7 +422,9 @@ class SlackAlerting(CustomLogger):
value = replaced_failed_values[top_5_failed[i]] value = replaced_failed_values[top_5_failed[i]]
message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n" message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
message += "\n\n*😅 Top 5 Slowest Deployments:*\n\n" message += "\n\n*😅 Top Slowest Deployments:*\n\n"
if not top_5_slowest:
message += "\tNone\n"
for i in range(len(top_5_slowest)): for i in range(len(top_5_slowest)):
key = latency_keys[top_5_slowest[i]].split(":")[0] key = latency_keys[top_5_slowest[i]].split(":")[0]
_deployment = router.get_model_info(key) _deployment = router.get_model_info(key)
@ -464,6 +476,11 @@ class SlackAlerting(CustomLogger):
messages = messages[:100] messages = messages[:100]
except: except:
messages = "" messages = ""
if litellm.turn_off_message_logging:
messages = (
"Message not logged. `litellm.turn_off_message_logging=True`."
)
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
else: else:
request_info = "" request_info = ""
@ -814,14 +831,6 @@ Model Info:
updated_at=litellm.utils.get_utc_datetime(), updated_at=litellm.utils.get_utc_datetime(),
) )
) )
if "llm_exceptions" in self.alert_types:
original_exception = kwargs.get("exception", None)
await self.send_alert(
message="LLM API Failure - " + str(original_exception),
level="High",
alert_type="llm_exceptions",
)
async def _run_scheduler_helper(self, llm_router) -> bool: async def _run_scheduler_helper(self, llm_router) -> bool:
""" """
@ -844,15 +853,22 @@ Model Info:
value=_current_time, value=_current_time,
) )
else: else:
# check if current time - interval >= time last sent # Check if current time - interval >= time last sent
delta = current_time - timedelta( delta_naive = timedelta(seconds=self.alerting_args.daily_report_frequency)
seconds=self.alerting_args.daily_report_frequency
)
if isinstance(report_sent, str): if isinstance(report_sent, str):
report_sent = dt.fromisoformat(report_sent) report_sent = dt.fromisoformat(report_sent)
if delta >= report_sent: # Ensure report_sent is an aware datetime object
if report_sent.tzinfo is None:
report_sent = report_sent.replace(tzinfo=timezone.utc)
# Calculate delta as an aware datetime object with the same timezone as report_sent
delta = report_sent - delta_naive
current_time_utc = current_time.astimezone(timezone.utc)
delta_utc = delta.astimezone(timezone.utc)
if current_time_utc >= delta_utc:
# Sneak in the reporting logic here # Sneak in the reporting logic here
await self.send_daily_reports(router=llm_router) await self.send_daily_reports(router=llm_router)
# Also, don't forget to update the report_sent time after sending the report! # Also, don't forget to update the report_sent time after sending the report!
@ -885,3 +901,99 @@ Model Info:
) # shuffle to prevent collisions ) # shuffle to prevent collisions
await asyncio.sleep(interval) await asyncio.sleep(interval)
return return
async def send_weekly_spend_report(self):
""" """
try:
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
todays_date = datetime.datetime.now().date()
week_before = todays_date - datetime.timedelta(days=7)
weekly_spend_per_team, weekly_spend_per_tag = (
await _get_spend_report_for_time_range(
start_date=week_before.strftime("%Y-%m-%d"),
end_date=todays_date.strftime("%Y-%m-%d"),
)
)
_weekly_spend_message = f"*💸 Weekly Spend Report for `{week_before.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` *\n"
if weekly_spend_per_team is not None:
_weekly_spend_message += "\n*Team Spend Report:*\n"
for spend in weekly_spend_per_team:
_team_spend = spend["total_spend"]
_team_spend = float(_team_spend)
# round to 4 decimal places
_team_spend = round(_team_spend, 4)
_weekly_spend_message += (
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
)
if weekly_spend_per_tag is not None:
_weekly_spend_message += "\n*Tag Spend Report:*\n"
for spend in weekly_spend_per_tag:
_tag_spend = spend["total_spend"]
_tag_spend = float(_tag_spend)
# round to 4 decimal places
_tag_spend = round(_tag_spend, 4)
_weekly_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
await self.send_alert(
message=_weekly_spend_message,
level="Low",
alert_type="daily_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)
async def send_monthly_spend_report(self):
""" """
try:
from calendar import monthrange
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
todays_date = datetime.datetime.now().date()
first_day_of_month = todays_date.replace(day=1)
_, last_day_of_month = monthrange(todays_date.year, todays_date.month)
last_day_of_month = first_day_of_month + datetime.timedelta(
days=last_day_of_month - 1
)
monthly_spend_per_team, monthly_spend_per_tag = (
await _get_spend_report_for_time_range(
start_date=first_day_of_month.strftime("%Y-%m-%d"),
end_date=last_day_of_month.strftime("%Y-%m-%d"),
)
)
_spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
if monthly_spend_per_team is not None:
_spend_message += "\n*Team Spend Report:*\n"
for spend in monthly_spend_per_team:
_team_spend = spend["total_spend"]
_team_spend = float(_team_spend)
# round to 4 decimal places
_team_spend = round(_team_spend, 4)
_spend_message += (
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
)
if monthly_spend_per_tag is not None:
_spend_message += "\n*Tag Spend Report:*\n"
for spend in monthly_spend_per_tag:
_tag_spend = spend["total_spend"]
_tag_spend = float(_tag_spend)
# round to 4 decimal places
_tag_spend = round(_tag_spend, 4)
_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
await self.send_alert(
message=_spend_message,
level="Low",
alert_type="daily_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)

View file

@ -3,8 +3,6 @@
import dotenv, os import dotenv, os
import requests # type: ignore import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm import litellm

View file

@ -21,11 +21,11 @@ try:
# contains a (known) object attribute # contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"] object: Literal["chat.completion", "edit", "text_completion"]
def __getitem__(self, key: K) -> V: def __getitem__(self, key: K) -> V: ... # noqa
... # pragma: no cover
def get(self, key: K, default: Optional[V] = None) -> Optional[V]: def get( # noqa
... # pragma: no cover self, key: K, default: Optional[V] = None
) -> Optional[V]: ... # pragma: no cover
class OpenAIRequestResponseResolver: class OpenAIRequestResponseResolver:
def __call__( def __call__(
@ -173,12 +173,11 @@ except:
#### What this does #### #### What this does ####
# On success, logs events to Langfuse # On success, logs events to Langfuse
import dotenv, os import os
import requests import requests
import requests import requests
from datetime import datetime from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -3,7 +3,7 @@ import json
from enum import Enum from enum import Enum
import requests, copy # type: ignore import requests, copy # type: ignore
import time import time
from typing import Callable, Optional, List from typing import Callable, Optional, List, Union
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def process_streaming_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> CustomStreamWrapper:
"""
Return stream object for tool-calling + streaming
"""
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": content["id"],
"type": "function",
"function": {
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
raise AnthropicError(
status_code=422,
message="Unprocessable response object - {}".format(response.text),
)
def process_response( def process_response(
self, self,
model, model: str,
response, response: Union[requests.Response, httpx.Response],
model_response, model_response: ModelResponse,
_is_function_call, stream: bool,
stream, logging_obj: litellm.utils.Logging,
logging_obj, optional_params: dict,
api_key, api_key: str,
data, data: Union[dict, str],
messages, messages: List,
print_verbose, print_verbose,
): encoding,
) -> ModelResponse:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=messages, input=messages,
@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM):
completion_response["stop_reason"] completion_response["stop_reason"]
) )
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"] prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"] completion_tokens = completion_response["usage"]["output_tokens"]
@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage) # type: ignore
return model_response return model_response
async def acompletion_stream_function( async def acompletion_stream_function(
@ -289,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj, logging_obj,
stream, stream,
_is_function_call, _is_function_call,
data=None, data: dict,
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj, logging_obj,
stream, stream,
_is_function_call, _is_function_call,
data=None, data: dict,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
headers={}, headers={},
): ) -> Union[ModelResponse, CustomStreamWrapper]:
self.async_handler = AsyncHTTPHandler( self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0) timeout=httpx.Timeout(timeout=600.0, connect=5.0)
) )
response = await self.async_handler.post( response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data) api_base, headers=headers, data=json.dumps(data)
) )
return self.process_response( if stream and _is_function_call:
return self.process_streaming_response(
model=model, model=model,
response=response, response=response,
model_response=model_response, model_response=model_response,
_is_function_call=_is_function_call,
stream=stream, stream=stream,
logging_obj=logging_obj, logging_obj=logging_obj,
api_key=api_key, api_key=api_key,
data=data, data=data,
messages=messages, messages=messages,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
) )
def completion( def completion(
@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
acompletion=None, acompletion=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
raise AnthropicError( raise AnthropicError(
status_code=response.status_code, message=response.text status_code=response.status_code, message=response.text
) )
return self.process_response(
if stream and _is_function_call:
return self.process_streaming_response(
model=model, model=model,
response=response, response=response,
model_response=model_response, model_response=model_response,
_is_function_call=_is_function_call,
stream=stream, stream=stream,
logging_obj=logging_obj, logging_obj=logging_obj,
api_key=api_key, api_key=api_key,
data=data, data=data,
messages=messages, messages=messages,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
) )
def embedding(self): def embedding(self):

View file

@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def process_response( def _process_response(
self, model_response: ModelResponse, response, encoding, prompt: str, model: str self, model_response: ModelResponse, response, encoding, prompt: str, model: str
): ):
## RESPONSE OBJECT ## RESPONSE OBJECT
@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
response = self.process_response( response = self._process_response(
model_response=model_response, model_response=model_response,
response=response, response=response,
encoding=encoding, encoding=encoding,
@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
response = self.process_response( response = self._process_response(
model_response=model_response, model_response=model_response,
response=response, response=response,
encoding=encoding, encoding=encoding,

View file

@ -8,14 +8,16 @@ from litellm.utils import (
CustomStreamWrapper, CustomStreamWrapper,
convert_to_model_response_object, convert_to_model_response_object,
TranscriptionResponse, TranscriptionResponse,
get_secret,
) )
from typing import Callable, Optional, BinaryIO from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
import httpx # type: ignore import httpx # type: ignore
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid import uuid
import os
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
@ -105,6 +107,12 @@ class AzureOpenAIConfig(OpenAIConfig):
optional_params["azure_ad_token"] = value optional_params["azure_ad_token"] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
"""
return ["europe", "sweden", "switzerland", "france", "uk"]
def select_azure_base_url_or_endpoint(azure_client_params: dict): def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = { # azure_client_params = {
@ -126,6 +134,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
return azure_client_params return azure_client_params
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
if azure_client_id is None or azure_tenant is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
req_token = httpx.post(
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
possible_azure_ad_token = req_token.json().get("access_token", None)
if possible_azure_ad_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token not returned"
)
return possible_azure_ad_token
class AzureChatCompletion(BaseLLM): class AzureChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -137,6 +190,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
headers["api-key"] = api_key headers["api-key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
headers["Authorization"] = f"Bearer {azure_ad_token}" headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers return headers
@ -189,6 +244,9 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if acompletion is True: if acompletion is True:
@ -276,6 +334,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None:
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
@ -351,6 +411,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client # setting Azure client
@ -422,6 +484,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None:
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
@ -478,6 +542,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params) azure_client = AsyncAzureOpenAI(**azure_client_params)
@ -599,6 +665,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
## LOGGING ## LOGGING
@ -755,6 +823,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True: if aimg_generation == True:
@ -833,6 +903,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None: if max_retries is not None:

View file

@ -1,12 +1,32 @@
## This is a template base class to be used for adding new LLM providers via API calls ## This is a template base class to be used for adding new LLM providers via API calls
import litellm import litellm
import httpx import httpx, requests
from typing import Optional from typing import Optional, Union
from litellm.utils import Logging
class BaseLLM: class BaseLLM:
_client_session: Optional[httpx.Client] = None _client_session: Optional[httpx.Client] = None
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> litellm.utils.ModelResponse:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self): def create_client_session(self):
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session

View file

@ -52,6 +52,16 @@ class AmazonBedrockGlobalConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.aws-services.info/bedrock.html
"""
return [
"eu-west-1",
"eu-west-3",
"eu-central-1",
]
class AmazonTitanConfig: class AmazonTitanConfig:
""" """
@ -551,6 +561,7 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None, aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None, aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None, aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
extra_headers: Optional[dict] = None, extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None,
): ):
@ -567,6 +578,7 @@ def init_bedrock_client(
aws_session_name, aws_session_name,
aws_profile_name, aws_profile_name,
aws_role_name, aws_role_name,
aws_web_identity_token,
] ]
# Iterate over parameters and update if needed # Iterate over parameters and update if needed
@ -582,6 +594,7 @@ def init_bedrock_client(
aws_session_name, aws_session_name,
aws_profile_name, aws_profile_name,
aws_role_name, aws_role_name,
aws_web_identity_token,
) = params_to_check ) = params_to_check
### SET REGION NAME ### SET REGION NAME
@ -620,7 +633,38 @@ def init_bedrock_client(
config = boto3.session.Config() config = boto3.session.Config()
### CHECK STS ### ### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None: if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts"
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in # use sts if role name passed in
sts_client = boto3.client( sts_client = boto3.client(
"sts", "sts",
@ -755,6 +799,7 @@ def completion(
aws_bedrock_runtime_endpoint = optional_params.pop( aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) )
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop("aws_bedrock_client", None) client = optional_params.pop("aws_bedrock_client", None)
@ -769,6 +814,7 @@ def completion(
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name=aws_session_name, aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name, aws_profile_name=aws_profile_name,
aws_web_identity_token=aws_web_identity_token,
extra_headers=extra_headers, extra_headers=extra_headers,
timeout=timeout, timeout=timeout,
) )
@ -1291,6 +1337,7 @@ def embedding(
aws_bedrock_runtime_endpoint = optional_params.pop( aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) )
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client( client = init_bedrock_client(
@ -1298,6 +1345,7 @@ def embedding(
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name=aws_session_name, aws_session_name=aws_session_name,
) )
@ -1380,6 +1428,7 @@ def image_generation(
aws_bedrock_runtime_endpoint = optional_params.pop( aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) )
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client( client = init_bedrock_client(
@ -1387,6 +1436,7 @@ def image_generation(
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name=aws_session_name, aws_session_name=aws_session_name,
timeout=timeout, timeout=timeout,

View file

@ -0,0 +1,733 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V0 - just covers cohere command-r support
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import (
Callable,
Optional,
List,
Literal,
Union,
Any,
TypedDict,
Tuple,
Iterator,
AsyncIterator,
)
from litellm.utils import (
ModelResponse,
Usage,
map_finish_reason,
CustomStreamWrapper,
Message,
Choices,
get_secret,
Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt
from litellm.types.llms.bedrock import *
class AmazonCohereChatConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
"""
documents: Optional[List[Document]] = None
search_queries_only: Optional[bool] = None
preamble: Optional[str] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
p: Optional[float] = None
k: Optional[float] = None
prompt_truncation: Optional[str] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
seed: Optional[int] = None
return_prompt: Optional[bool] = None
stop_sequences: Optional[List[str]] = None
raw_prompting: Optional[bool] = None
def __init__(
self,
documents: Optional[List[Document]] = None,
search_queries_only: Optional[bool] = None,
preamble: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
p: Optional[float] = None,
k: Optional[float] = None,
prompt_truncation: Optional[str] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
return_prompt: Optional[bool] = None,
stop_sequences: Optional[str] = None,
raw_prompting: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self) -> List[str]:
return [
"max_tokens",
"stream",
"stop",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
"seed",
"stop",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if "seed":
optional_params["seed"] = value
return optional_params
class BedrockLLM(BaseLLM):
"""
Example call
```
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
--header 'Content-Type: application/json' \
--header 'Accept: application/json' \
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
--data-raw '{
"prompt": "Hi",
"temperature": 0,
"p": 0.9,
"max_tokens": 4096
}'
```
"""
def __init__(self) -> None:
super().__init__()
def convert_messages_to_prompt(
self, model, messages, provider, custom_prompt_dict
) -> Tuple[str, Optional[list]]:
# handle anthropic prompts and amazon titan prompts
prompt = ""
chat_history: Optional[list] = None
if provider == "anthropic" or provider == "amazon":
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "mistral":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "cohere":
prompt, chat_history = cohere_message_pt(messages=messages)
else:
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
return prompt, chat_history # type: ignore
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
) = params_to_check
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
return sts_response["Credentials"]
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise BedrockError(message=response.text, status_code=422)
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
raise BedrockError(message=response.text, status_code=422)
## CALCULATING USAGE - bedrock returns usage in the headers
prompt_tokens = int(
response.headers.get(
"x-amzn-bedrock-input-token-count",
len(encoding.encode("".join(m.get("content", "") for m in messages))),
)
)
completion_tokens = int(
response.headers.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response.choices[0].message.content, # type: ignore
disallowed_special=(),
)
),
)
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if stream is not None and stream == True:
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
else:
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
provider = model.split(".")[0]
prompt, chat_history = self.convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
if provider == "cohere":
if model.startswith("cohere.command-r"):
## LOAD CONFIG
config = litellm.AmazonCohereChatConfig().get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
_data = {"message": prompt, **inference_params}
if chat_history is not None:
_data["chat_history"] = chat_history
data = json.dumps(_data)
else:
## LOAD CONFIG
config = litellm.AmazonCohereConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream == True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
data = json.dumps({"prompt": prompt, **inference_params})
else:
raise Exception("UNSUPPORTED PROVIDER")
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = HTTPHandler(**_params) # type: ignore
else:
self.client = client
if stream is not None and stream == True:
response = self.client.post(
url=prepped.url,
headers=prepped.headers, # type: ignore
data=data,
stream=stream,
)
if response.status_code != 200:
raise BedrockError(
status_code=response.status_code, message=response.text
)
decoder = AWSEventStreamDecoder()
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
try:
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.text)
decoder = AWSEventStreamDecoder()
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
def embedding(self, *args, **kwargs):
return super().embedding(*args, **kwargs)
def get_response_stream_shape():
from botocore.model import ServiceModel
from botocore.loaders import Loader
loader = Loader()
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
bedrock_service_model = ServiceModel(bedrock_service_dict)
return bedrock_service_model.shape_for("ResponseStream")
class AWSEventStreamDecoder:
def __init__(self) -> None:
from botocore.parsers import EventStreamJSONParser
self.parser = EventStreamJSONParser()
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
# sse_event = ServerSentEvent(data=message, event="completion")
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[GenericStreamingChunk]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]

328
litellm/llms/clarifai.py Normal file
View file

@ -0,0 +1,328 @@
import os, types, traceback
import json
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, Choices, Message, CustomStreamWrapper
import litellm
import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .prompt_templates.factory import prompt_factory, custom_prompt
class ClarifaiError(Exception):
def __init__(self, status_code, message, url):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=url
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
)
class ClarifaiConfig:
"""
Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
TODO fill in the details
"""
max_tokens: Optional[int] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completions_to_model(payload):
# if payload["n"] != 1:
# raise HTTPException(
# status_code=422,
# detail="Only one generation is supported. Please set candidate_count to 1.",
# )
params = {}
if temperature := payload.get("temperature"):
params["temperature"] = temperature
if max_tokens := payload.get("max_tokens"):
params["max_tokens"] = max_tokens
return {
"inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
"model": {"output_info": {"params": params}},
}
def process_response(
model,
prompt,
response,
model_response,
api_key,
data,
encoding,
logging_obj
):
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception:
raise ClarifaiError(
message=response.text, status_code=response.status_code, url=model
)
# print(completion_response)
try:
choices_list = []
for idx, item in enumerate(completion_response["outputs"]):
if len(item["data"]["text"]["raw"]) > 0:
message_obj = Message(content=item["data"]["text"]["raw"])
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason="stop",
index=idx + 1, #check
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model
)
# Calculate Usage
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
model_response["model"] = model
model_response["usage"] = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
def convert_model_to_url(model: str, api_base: str):
user_id, app_id, model_id = model.split(".")
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
def get_prompt_model_name(url: str):
clarifai_model_name = url.split("/")[-2]
if "claude" in clarifai_model_name:
return "anthropic", clarifai_model_name.replace("_", ".")
if ("llama" in clarifai_model_name)or ("mistral" in clarifai_model_name):
return "", "meta-llama/llama-2-chat"
else:
return "", clarifai_model_name
async def async_completion(
model: str,
prompt: str,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={}):
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return process_response(
model=model,
prompt=prompt,
response=response,
model_response=model_response,
api_key=api_key,
data=data,
encoding=encoding,
logging_obj=logging_obj,
)
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
custom_prompt_dict={},
acompletion=False,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
model = convert_model_to_url(model, api_base)
prompt = " ".join(message["content"] for message in messages) # TODO
## Load Config
config = litellm.ClarifaiConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
):
optional_params[k] = v
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
if custom_llm_provider == "anthropic":
prompt = prompt_factory(
model=orig_model_name,
messages=messages,
api_key=api_key,
custom_llm_provider="clarifai"
)
else:
prompt = prompt_factory(
model=orig_model_name,
messages=messages,
api_key=api_key,
custom_llm_provider=custom_llm_provider
)
# print(prompt); exit(0)
data = {
"prompt": prompt,
**optional_params,
}
data = completions_to_model(data)
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
if acompletion==True:
return async_completion(
model=model,
prompt=prompt,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
data=data,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
response = requests.post(
model,
headers=headers,
data=json.dumps(data),
)
# print(response.content); exit()
if response.status_code != 200:
raise ClarifaiError(status_code=response.status_code, message=response.text, url=model)
if "stream" in optional_params and optional_params["stream"] == True:
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="clarifai",
logging_obj=logging_obj,
)
return stream_response
else:
return process_response(
model=model,
prompt=prompt,
response=response,
model_response=model_response,
api_key=api_key,
data=data,
encoding=encoding,
logging_obj=logging_obj)
class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response
# Async iterator
def __aiter__(self):
return self
async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response

View file

@ -58,8 +58,15 @@ class AsyncHTTPHandler:
class HTTPHandler: class HTTPHandler:
def __init__( def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000 self,
timeout: Optional[httpx.Timeout] = None,
concurrent_limit=1000,
client: Optional[httpx.Client] = None,
): ):
if timeout is None:
timeout = _DEFAULT_TIMEOUT
if client is None:
# Create a client with a connection pool # Create a client with a connection pool
self.client = httpx.Client( self.client = httpx.Client(
timeout=timeout, timeout=timeout,
@ -68,6 +75,8 @@ class HTTPHandler:
max_keepalive_connections=concurrent_limit, max_keepalive_connections=concurrent_limit,
), ),
) )
else:
self.client = client
def close(self): def close(self):
# Close the client when you're done with it # Close the client when you're done with it
@ -82,11 +91,15 @@ class HTTPHandler:
def post( def post(
self, self,
url: str, url: str,
data: Optional[dict] = None, data: Optional[Union[dict, str]] = None,
params: Optional[dict] = None, params: Optional[dict] = None,
headers: Optional[dict] = None, headers: Optional[dict] = None,
stream: bool = False,
): ):
response = self.client.post(url, data=data, params=params, headers=headers) req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response return response
def __del__(self) -> None: def __del__(self) -> None:

View file

@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM from .base import BaseLLM
import time import time
import litellm import litellm
from typing import Callable, Dict, List, Any from typing import Callable, Dict, List, Any, Literal, Tuple
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum
class HuggingfaceError(Exception): class HuggingfaceError(Exception):
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
hf_task_list = [
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
hf_tasks = Literal[
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
class HuggingfaceConfig: class HuggingfaceConfig:
""" """
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
""" """
hf_task: Optional[hf_tasks] = (
None # litellm-specific param, used to know the api spec to use when calling huggingface api
)
best_of: Optional[int] = None best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of details: Optional[bool] = True # enables returning logprobs + best of
@ -101,6 +121,51 @@ class HuggingfaceConfig:
and v is not None and v is not None
} }
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
def output_parser(generated_text: str): def output_parser(generated_text: str):
""" """
@ -162,18 +227,21 @@ def read_tgi_conv_models():
return set(), set() return set(), set()
def get_hf_task_for_model(model): def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
# read text file, cast it to set # read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt" # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
split_model = model.split("/", 1)
return split_model[0], split_model[1] # type: ignore
tgi_models, conversational_models = read_tgi_conv_models() tgi_models, conversational_models = read_tgi_conv_models()
if model in tgi_models: if model in tgi_models:
return "text-generation-inference" return "text-generation-inference", model
elif model in conversational_models: elif model in conversational_models:
return "conversational" return "conversational", model
elif "roneneldan/TinyStories" in model: elif "roneneldan/TinyStories" in model:
return None return "text-generation", model
else: else:
return "text-generation-inference" # default to tgi return "text-generation-inference", model # default to tgi
class Huggingface(BaseLLM): class Huggingface(BaseLLM):
@ -202,7 +270,7 @@ class Huggingface(BaseLLM):
self, self,
completion_response, completion_response,
model_response, model_response,
task, task: hf_tasks,
optional_params, optional_params,
encoding, encoding,
input_text, input_text,
@ -270,6 +338,10 @@ class Huggingface(BaseLLM):
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"].extend(choices_list) model_response["choices"].extend(choices_list)
elif task == "text-classification":
model_response["choices"][0]["message"]["content"] = json.dumps(
completion_response
)
else: else:
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser( model_response["choices"][0]["message"]["content"] = output_parser(
@ -332,7 +404,13 @@ class Huggingface(BaseLLM):
exception_mapping_worked = False exception_mapping_worked = False
try: try:
headers = self.validate_environment(api_key, headers) headers = self.validate_environment(api_key, headers)
task = get_hf_task_for_model(model) task, model = get_hf_task_for_model(model)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
print_verbose(f"{model}, {task}") print_verbose(f"{model}, {task}")
completion_url = "" completion_url = ""
input_text = "" input_text = ""
@ -433,14 +511,15 @@ class Huggingface(BaseLLM):
inference_params.pop("return_full_text") inference_params.pop("return_full_text")
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": inference_params, }
"stream": ( # type: ignore if task == "text-generation-inference":
True data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params if "stream" in optional_params
and optional_params["stream"] == True and optional_params["stream"] == True
else False else False
), )
}
input_text = prompt input_text = prompt
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -531,10 +610,10 @@ class Huggingface(BaseLLM):
isinstance(completion_response, dict) isinstance(completion_response, dict)
and "error" in completion_response and "error" in completion_response
): ):
print_verbose(f"completion error: {completion_response['error']}") print_verbose(f"completion error: {completion_response['error']}") # type: ignore
print_verbose(f"response.status_code: {response.status_code}") print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError( raise HuggingfaceError(
message=completion_response["error"], message=completion_response["error"], # type: ignore
status_code=response.status_code, status_code=response.status_code,
) )
return self.convert_to_model_response_object( return self.convert_to_model_response_object(
@ -563,7 +642,7 @@ class Huggingface(BaseLLM):
data: dict, data: dict,
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
task: str, task: hf_tasks,
encoding: Any, encoding: Any,
input_text: str, input_text: str,
model: str, model: str,

View file

@ -300,7 +300,7 @@ def get_ollama_response(
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
@ -484,7 +484,7 @@ async def ollama_acompletion(
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama_chat/" + data["model"] model_response["model"] = "ollama_chat/" + data["model"]

View file

@ -53,6 +53,113 @@ class OpenAIError(Exception):
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
return optional_params
class OpenAIConfig: class OpenAIConfig:
""" """
Reference: https://platform.openai.com/docs/api-reference/chat/create Reference: https://platform.openai.com/docs/api-reference/chat/create
@ -1327,8 +1434,8 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client, client=client,
) )
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data thread_id, **message_data # type: ignore
) )
response_obj: Optional[OpenAIMessage] = None response_obj: Optional[OpenAIMessage] = None
@ -1458,7 +1565,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client, client=client,
) )
response = openai_client.beta.threads.runs.create_and_poll( response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id, thread_id=thread_id,
assistant_id=assistant_id, assistant_id=assistant_id,
additional_instructions=additional_instructions, additional_instructions=additional_instructions,

View file

@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj: litellm.utils.Logging, logging_obj: litellm.utils.Logging,
optional_params: dict, optional_params: dict,
api_key: str, api_key: str,
data: dict, data: Union[dict, str],
messages: list, messages: list,
print_verbose, print_verbose,
encoding, encoding,
@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
try: try:
completion_response = response.json() completion_response = response.json()
except: except:
raise PredibaseError( raise PredibaseError(message=response.text, status_code=422)
message=response.text, status_code=response.status_code
)
if "error" in completion_response: if "error" in completion_response:
raise PredibaseError( raise PredibaseError(
message=str(completion_response["error"]), message=str(completion_response["error"]),
@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
}, },
) )
## COMPLETION CALL ## COMPLETION CALL
if acompletion is True: if acompletion == True:
### ASYNC STREAMING ### ASYNC STREAMING
if stream == True: if stream == True:
return self.async_streaming( return self.async_streaming(

View file

@ -1509,6 +1509,11 @@ def prompt_factory(
model="meta-llama/Meta-Llama-3-8B-Instruct", model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=messages, messages=messages,
) )
elif custom_llm_provider == "clarifai":
if "claude" in model:
return anthropic_pt(messages=messages)
elif custom_llm_provider == "perplexity": elif custom_llm_provider == "perplexity":
for message in messages: for message in messages:
message.pop("name", None) message.pop("name", None)

119
litellm/llms/triton.py Normal file
View file

@ -0,0 +1,119 @@
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class TritonError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST",
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class TritonChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
async def aembedding(
self,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
api_base: str,
logging_obj=None,
api_key: Optional[str] = None,
):
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(url=api_base, data=json.dumps(data))
if response.status_code != 200:
raise TritonError(status_code=response.status_code, message=response.text)
_text_response = response.text
logging_obj.post_call(original_response=_text_response)
_json_response = response.json()
_outputs = _json_response["outputs"]
_output_data = _outputs[0]["data"]
_embedding_output = {
"object": "embedding",
"index": 0,
"embedding": _output_data,
}
model_response.model = _json_response.get("model_name", "None")
model_response.data = [_embedding_output]
return model_response
def embedding(
self,
model: str,
input: list,
timeout: float,
api_base: str,
model_response: litellm.utils.EmbeddingResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aembedding=None,
):
data_for_triton = {
"inputs": [
{
"name": "input_text",
"shape": [1],
"datatype": "BYTES",
"data": input,
}
]
}
## LOGGING
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
logging_obj.pre_call(
input="",
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": curl_string,
},
)
if aembedding == True:
response = self.aembedding(
data=data_for_triton,
model_response=model_response,
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
)
return response
else:
raise Exception(
"Only async embedding supported for triton, please use litellm.aembedding() for now"
)

View file

@ -198,6 +198,23 @@ class VertexAIConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
"""
return [
"europe-central2",
"europe-north1",
"europe-southwest1",
"europe-west1",
"europe-west2",
"europe-west3",
"europe-west4",
"europe-west6",
"europe-west8",
"europe-west9",
]
import asyncio import asyncio
@ -419,6 +436,7 @@ def completion(
from google.protobuf.struct_pb2 import Value # type: ignore from google.protobuf.struct_pb2 import Value # type: ignore
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
import google.auth # type: ignore import google.auth # type: ignore
import proto # type: ignore
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
print_verbose( print_verbose(
@ -605,9 +623,21 @@ def completion(
): ):
function_call = response.candidates[0].content.parts[0].function_call function_call = response.candidates[0].content.parts[0].function_call
args_dict = {} args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v # Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite
):
# If so, convert to list
args_dict[key] = [v for v in val]
else:
args_dict[key] = val
try:
args_str = json.dumps(args_dict) args_str = json.dumps(args_dict)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[
@ -810,6 +840,8 @@ def completion(
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
return model_response return model_response
except Exception as e: except Exception as e:
if isinstance(e, VertexAIError):
raise e
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
@ -835,6 +867,8 @@ async def async_completion(
Add support for acompletion calls for gemini-pro Add support for acompletion calls for gemini-pro
""" """
try: try:
import proto # type: ignore
if mode == "vision": if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call") print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
print_verbose(f"\nProcessing input messages = {messages}") print_verbose(f"\nProcessing input messages = {messages}")
@ -869,9 +903,21 @@ async def async_completion(
): ):
function_call = response.candidates[0].content.parts[0].function_call function_call = response.candidates[0].content.parts[0].function_call
args_dict = {} args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v # Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite
):
# If so, convert to list
args_dict[key] = [v for v in val]
else:
args_dict[key] = val
try:
args_str = json.dumps(args_dict) args_str = json.dumps(args_dict)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[

View file

@ -1,12 +1,26 @@
from enum import Enum from enum import Enum
import json, types, time # noqa: E401 import json, types, time # noqa: E401
from contextlib import contextmanager from contextlib import asynccontextmanager, contextmanager
from typing import Callable, Dict, Optional, Any, Union, List from typing import (
Callable,
Dict,
Generator,
AsyncGenerator,
Iterator,
AsyncIterator,
Optional,
Any,
Union,
List,
ContextManager,
AsyncContextManager,
)
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
import litellm import litellm
from litellm.utils import ModelResponse, get_secret, Usage from litellm.utils import ModelResponse, Usage, get_secret
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM from .base import BaseLLM
from .prompt_templates import factory as ptf from .prompt_templates import factory as ptf
@ -149,6 +163,15 @@ class IBMWatsonXAIConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
"""
return [
"eu-de",
"eu-gb",
]
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# handle anthropic prompts and amazon titan prompts # handle anthropic prompts and amazon titan prompts
@ -188,11 +211,12 @@ class WatsonXAIEndpoint(str, Enum):
) )
EMBEDDINGS = "/ml/v1/text/embeddings" EMBEDDINGS = "/ml/v1/text/embeddings"
PROMPTS = "/ml/v1/prompts" PROMPTS = "/ml/v1/prompts"
AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
class IBMWatsonXAI(BaseLLM): class IBMWatsonXAI(BaseLLM):
""" """
Class to interface with IBM Watsonx.ai API for text generation and embeddings. Class to interface with IBM watsonx.ai API for text generation and embeddings.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai Reference: https://cloud.ibm.com/apidocs/watsonx-ai
""" """
@ -343,7 +367,7 @@ class IBMWatsonXAI(BaseLLM):
) )
if token is None and api_key is not None: if token is None and api_key is not None:
# generate the auth token # generate the auth token
if print_verbose: if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai") print_verbose("Generating IAM token for Watsonx.ai")
token = self.generate_iam_token(api_key) token = self.generate_iam_token(api_key)
elif token is None and api_key is None: elif token is None and api_key is None:
@ -378,10 +402,11 @@ class IBMWatsonXAI(BaseLLM):
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params: dict, optional_params=None,
litellm_params: Optional[dict] = None, acompletion=None,
litellm_params=None,
logger_fn=None, logger_fn=None,
timeout: Optional[float] = None, timeout=None,
): ):
""" """
Send a text generation request to the IBM Watsonx.ai API. Send a text generation request to the IBM Watsonx.ai API.
@ -402,12 +427,12 @@ class IBMWatsonXAI(BaseLLM):
model, messages, provider, custom_prompt_dict model, messages, provider, custom_prompt_dict
) )
def process_text_request(request_params: dict) -> ModelResponse: def process_text_gen_response(json_resp: dict) -> ModelResponse:
with self._manage_response( if "results" not in json_resp:
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout raise WatsonXAIError(
) as resp: status_code=500,
json_resp = resp.json() message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
)
generated_text = json_resp["results"][0]["generated_text"] generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"] prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"] completion_tokens = json_resp["results"][0]["generated_token_count"]
@ -415,36 +440,70 @@ class IBMWatsonXAI(BaseLLM):
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"] model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
setattr( usage = Usage(
model_response,
"usage",
Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
),
) )
setattr(model_response, "usage", usage)
return model_response return model_response
def process_stream_request( def process_stream_response(
request_params: dict, stream_resp: Union[Iterator[str], AsyncIterator],
) -> litellm.CustomStreamWrapper: ) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled streamwrapper = litellm.CustomStreamWrapper(
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream stream_resp,
with self._manage_response(
request_params,
logging_obj=logging_obj,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
response = litellm.CustomStreamWrapper(
resp.iter_lines(),
model=model, model=model,
custom_llm_provider="watsonx", custom_llm_provider="watsonx",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
return response return streamwrapper
# create the function to manage the request to watsonx.ai
self.request_manager = RequestManager(logging_obj)
def handle_text_request(request_params: dict) -> ModelResponse:
with self.request_manager.request(
request_params,
input=prompt,
timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
async def handle_text_request_async(request_params: dict) -> ModelResponse:
async with self.request_manager.async_request(
request_params,
input=prompt,
timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self.request_manager.request(
request_params,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
streamwrapper = process_stream_response(resp.iter_lines())
return streamwrapper
async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
async with self.request_manager.async_request(
request_params,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
streamwrapper = process_stream_response(resp.aiter_lines())
return streamwrapper
try: try:
## Get the response from the model ## Get the response from the model
@ -455,10 +514,18 @@ class IBMWatsonXAI(BaseLLM):
optional_params=optional_params, optional_params=optional_params,
print_verbose=print_verbose, print_verbose=print_verbose,
) )
if stream: if stream and (acompletion is True):
return process_stream_request(req_params) # stream and async text generation
return handle_stream_request_async(req_params)
elif stream:
# streaming text generation
return handle_stream_request(req_params)
elif (acompletion is True):
# async text generation
return handle_text_request_async(req_params)
else: else:
return process_text_request(req_params) # regular text generation
return handle_text_request(req_params)
except WatsonXAIError as e: except WatsonXAIError as e:
raise e raise e
except Exception as e: except Exception as e:
@ -473,6 +540,7 @@ class IBMWatsonXAI(BaseLLM):
model_response=None, model_response=None,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
aembedding=None,
): ):
""" """
Send a text embedding request to the IBM Watsonx.ai API. Send a text embedding request to the IBM Watsonx.ai API.
@ -507,9 +575,6 @@ class IBMWatsonXAI(BaseLLM):
} }
request_params = dict(version=api_params["api_version"]) request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
# request = httpx.Request(
# "POST", url, headers=headers, json=payload, params=request_params
# )
req_params = { req_params = {
"method": "POST", "method": "POST",
"url": url, "url": url,
@ -517,26 +582,50 @@ class IBMWatsonXAI(BaseLLM):
"json": payload, "json": payload,
"params": request_params, "params": request_params,
} }
with self._manage_response( request_manager = RequestManager(logging_obj)
req_params, logging_obj=logging_obj, input=input
) as resp:
json_resp = resp.json()
def process_embedding_response(json_resp: dict) -> ModelResponse:
results = json_resp.get("results", []) results = json_resp.get("results", [])
embedding_response = [] embedding_response = []
for idx, result in enumerate(results): for idx, result in enumerate(results):
embedding_response.append( embedding_response.append(
{"object": "embedding", "index": idx, "embedding": result["embedding"]} {
"object": "embedding",
"index": idx,
"embedding": result["embedding"],
}
) )
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = embedding_response model_response["data"] = embedding_response
model_response["model"] = model model_response["model"] = model
input_tokens = json_resp.get("input_token_count", 0) input_tokens = json_resp.get("input_token_count", 0)
model_response.usage = Usage( model_response.usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
) )
return model_response return model_response
def handle_embedding(request_params: dict) -> ModelResponse:
with request_manager.request(request_params, input=input) as resp:
json_resp = resp.json()
return process_embedding_response(json_resp)
async def handle_aembedding(request_params: dict) -> ModelResponse:
async with request_manager.async_request(request_params, input=input) as resp:
json_resp = resp.json()
return process_embedding_response(json_resp)
try:
if aembedding is True:
return handle_embedding(req_params)
else:
return handle_aembedding(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def generate_iam_token(self, api_key=None, **params): def generate_iam_token(self, api_key=None, **params):
headers = {} headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded" headers["Content-Type"] = "application/x-www-form-urlencoded"
@ -558,52 +647,144 @@ class IBMWatsonXAI(BaseLLM):
self.token = iam_access_token self.token = iam_access_token
return iam_access_token return iam_access_token
@contextmanager def get_available_models(self, *, ids_only: bool = True, **params):
def _manage_response( api_params = self._get_api_params(params)
headers = {
"Authorization": f"Bearer {api_params['token']}",
"Content-Type": "application/json",
"Accept": "application/json",
}
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
req_params = dict(method="GET", url=url, headers=headers, params=request_params)
with RequestManager(logging_obj=None).request(req_params) as resp:
json_resp = resp.json()
if not ids_only:
return json_resp
return [res["model_id"] for res in json_resp["resources"]]
class RequestManager:
"""
Returns a context manager that manages the response from the request.
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
Usage:
```python
request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
request_manager = RequestManager(logging_obj=logging_obj)
async with request_manager.request(request_params) as resp:
...
# or
with request_manager.async_request(request_params) as resp:
...
```
"""
def __init__(self, logging_obj=None):
self.logging_obj = logging_obj
def pre_call(
self, self,
request_params: dict, request_params: dict,
logging_obj: Any,
stream: bool = False,
input: Optional[Any] = None, input: Optional[Any] = None,
timeout: Optional[float] = None,
): ):
if self.logging_obj is None:
return
request_str = ( request_str = (
f"response = {request_params['method']}(\n" f"response = {request_params['method']}(\n"
f"\turl={request_params['url']},\n" f"\turl={request_params['url']},\n"
f"\tjson={request_params['json']},\n" f"\tjson={request_params.get('json')},\n"
f")" f")"
) )
logging_obj.pre_call( self.logging_obj.pre_call(
input=input, input=input,
api_key=request_params["headers"].get("Authorization"), api_key=request_params["headers"].get("Authorization"),
additional_args={ additional_args={
"complete_input_dict": request_params["json"], "complete_input_dict": request_params.get("json"),
"request_str": request_str, "request_str": request_str,
}, },
) )
if timeout:
request_params["timeout"] = timeout def post_call(self, resp, request_params):
try: if self.logging_obj is None:
if stream: return
resp = requests.request( self.logging_obj.post_call(
**request_params,
stream=True,
)
resp.raise_for_status()
yield resp
else:
resp = requests.request(**request_params)
resp.raise_for_status()
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
logging_obj.post_call(
input=input, input=input,
api_key=request_params["headers"].get("Authorization"), api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()), original_response=json.dumps(resp.json()),
additional_args={ additional_args={
"status_code": resp.status_code, "status_code": resp.status_code,
"complete_input_dict": request_params["json"], "complete_input_dict": request_params.get(
"data", request_params.get("json")
),
}, },
) )
@contextmanager
def request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
@asynccontextmanager
async def async_request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
)
# async_handler.client.verify = False
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp.status_code not in [200, 201]:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)

View file

@ -9,12 +9,13 @@
import os, openai, sys, json, inspect, uuid, datetime, threading import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union, BinaryIO from typing import Any, Literal, Union, BinaryIO
from typing_extensions import overload
from functools import partial from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy from copy import deepcopy
import httpx import httpx
import litellm
import litellm
from ._logging import verbose_logger from ._logging import verbose_logger
from litellm import ( # type: ignore from litellm import ( # type: ignore
client, client,
@ -47,6 +48,7 @@ from .llms import (
ai21, ai21,
sagemaker, sagemaker,
bedrock, bedrock,
triton,
huggingface_restapi, huggingface_restapi,
replicate, replicate,
aleph_alpha, aleph_alpha,
@ -56,6 +58,7 @@ from .llms import (
ollama, ollama,
ollama_chat, ollama_chat,
cloudflare, cloudflare,
clarifai,
cohere, cohere,
cohere_chat, cohere_chat,
petals, petals,
@ -75,6 +78,8 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import ( from .llms.prompt_templates.factory import (
prompt_factory, prompt_factory,
custom_prompt, custom_prompt,
@ -103,7 +108,6 @@ from litellm.utils import (
) )
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
openai_chat_completions = OpenAIChatCompletion() openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion() openai_text_completions = OpenAITextCompletion()
anthropic_chat_completions = AnthropicChatCompletion() anthropic_chat_completions = AnthropicChatCompletion()
@ -112,6 +116,8 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion() azure_text_completions = AzureTextCompletion()
huggingface = Huggingface() huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion() predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
####### COMPLETION ENDPOINTS ################ ####### COMPLETION ENDPOINTS ################
@ -254,7 +260,7 @@ async def acompletion(
- If `stream` is True, the function returns an async generator that yields completion lines. - If `stream` is True, the function returns an async generator that yields completion lines.
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
custom_llm_provider = None custom_llm_provider = kwargs.get("custom_llm_provider", None)
# Adjusted to use explicit arguments instead of *args and **kwargs # Adjusted to use explicit arguments instead of *args and **kwargs
completion_kwargs = { completion_kwargs = {
"model": model, "model": model,
@ -286,6 +292,7 @@ async def acompletion(
"model_list": model_list, "model_list": model_list,
"acompletion": True, # assuming this is a required parameter "acompletion": True, # assuming this is a required parameter
} }
if custom_llm_provider is None:
_, custom_llm_provider, _, _ = get_llm_provider( _, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=completion_kwargs.get("base_url", None) model=model, api_base=completion_kwargs.get("base_url", None)
) )
@ -297,9 +304,6 @@ async def acompletion(
ctx = contextvars.copy_context() ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func) func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
if ( if (
custom_llm_provider == "openai" custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
@ -321,6 +325,7 @@ async def acompletion(
or custom_llm_provider == "sagemaker" or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic" or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase" or custom_llm_provider == "predibase"
or (custom_llm_provider == "bedrock" and "cohere" in model)
or custom_llm_provider in litellm.openai_compatible_providers or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
@ -661,6 +666,7 @@ def completion(
"supports_system_message", "supports_system_message",
"region_name", "region_name",
"allowed_model_region", "allowed_model_region",
"model_config",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
@ -668,20 +674,6 @@ def completion(
k: v for k, v in kwargs.items() if k not in default_params k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider } # model-specific params - pass them straight to the model/provider
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
try: try:
if base_url is not None: if base_url is not None:
api_base = base_url api_base = base_url
@ -721,9 +713,18 @@ def completion(
"aws_region_name", None "aws_region_name", None
) # support region-based pricing for bedrock ) # support region-based pricing for bedrock
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
custom_llm_provider
):
timeout = timeout.read or 600 # default 10 min timeout
elif not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None: if input_cost_per_token is not None and output_cost_per_token is not None:
print_verbose(f"Registering model={model} in model cost map")
litellm.register_model( litellm.register_model(
{ {
f"{custom_llm_provider}/{model}": { f"{custom_llm_provider}/{model}": {
@ -845,6 +846,10 @@ def completion(
proxy_server_request=proxy_server_request, proxy_server_request=proxy_server_request,
preset_cache_key=preset_cache_key, preset_cache_key=preset_cache_key,
no_log=no_log, no_log=no_log,
input_cost_per_second=input_cost_per_second,
input_cost_per_token=input_cost_per_token,
output_cost_per_second=output_cost_per_second,
output_cost_per_token=output_cost_per_token,
) )
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -1210,6 +1215,61 @@ def completion(
) )
response = model_response response = model_response
elif (
"clarifai" in model
or custom_llm_provider == "clarifai"
or model in litellm.clarifai_models
):
clarifai_key = None
clarifai_key = (
api_key
or litellm.clarifai_key
or litellm.api_key
or get_secret("CLARIFAI_API_KEY")
or get_secret("CLARIFAI_API_TOKEN")
)
api_base = (
api_base
or litellm.api_base
or get_secret("CLARIFAI_API_BASE")
or "https://api.clarifai.com/v2"
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = clarifai.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
acompletion=acompletion,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=clarifai_key,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=model_response,
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=clarifai_key,
original_response=model_response,
)
response = model_response
elif custom_llm_provider == "anthropic": elif custom_llm_provider == "anthropic":
api_key = ( api_key = (
@ -1919,6 +1979,24 @@ def completion(
elif custom_llm_provider == "bedrock": elif custom_llm_provider == "bedrock":
# boto3 reads keys from .env # boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if "cohere" in model:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
)
else:
response = bedrock.completion( response = bedrock.completion(
model=model, model=model,
messages=messages, messages=messages,
@ -2622,6 +2700,7 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "voyage" or custom_llm_provider == "voyage"
or custom_llm_provider == "mistral" or custom_llm_provider == "mistral"
or custom_llm_provider == "custom_openai" or custom_llm_provider == "custom_openai"
or custom_llm_provider == "triton"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter" or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
@ -2779,6 +2858,7 @@ def embedding(
"no-log", "no-log",
"region_name", "region_name",
"allowed_model_region", "allowed_model_region",
"model_config",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {
@ -2955,6 +3035,23 @@ def embedding(
optional_params=optional_params, optional_params=optional_params,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
) )
elif custom_llm_provider == "triton":
if api_base is None:
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
vertex_ai_project = ( vertex_ai_project = (
optional_params.pop("vertex_project", None) optional_params.pop("vertex_project", None)
@ -3662,6 +3759,7 @@ def image_generation(
"cache", "cache",
"region_name", "region_name",
"allowed_model_region", "allowed_model_region",
"model_config",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {

View file

@ -9,6 +9,30 @@
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true
}, },
"gpt-4o": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4-turbo-preview": { "gpt-4-turbo-preview": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 128000, "max_input_tokens": 128000,
@ -1086,6 +1110,36 @@
"supports_tool_choice": true, "supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-1.5-flash-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000625,
"output_cost_per_token": 0.000001875,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0215": { "gemini-1.5-pro-preview-0215": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1000000, "max_input_tokens": 1000000,
@ -1331,6 +1385,24 @@
"mode": "completion", "mode": "completion",
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini/gemini-1.5-flash-latest": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-pro": { "gemini/gemini-pro": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 32760, "max_input_tokens": 32760,
@ -1571,6 +1643,159 @@
"litellm_provider": "replicate", "litellm_provider": "replicate",
"mode": "chat" "mode": "chat"
}, },
"openrouter/microsoft/wizardlm-2-8x22b:nitro": {
"max_tokens": 65536,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/google/gemini-pro-1.5": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0000025,
"output_cost_per_token": 0.0000075,
"input_cost_per_image": 0.00265,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/mistralai/mixtral-8x22b-instruct": {
"max_tokens": 65536,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000065,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/cohere/command-r-plus": {
"max_tokens": 128000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/databricks/dbrx-instruct": {
"max_tokens": 32768,
"input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000006,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/anthropic/claude-3-haiku": {
"max_tokens": 200000,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"input_cost_per_image": 0.0004,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/anthropic/claude-3-sonnet": {
"max_tokens": 200000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"input_cost_per_image": 0.0048,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/mistralai/mistral-large": {
"max_tokens": 32000,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/cognitivecomputations/dolphin-mixtral-8x7b": {
"max_tokens": 32769,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/google/gemini-pro-vision": {
"max_tokens": 45875,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000375,
"input_cost_per_image": 0.0025,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/fireworks/firellava-13b": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-8b-instruct:free": {
"max_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-8b-instruct:extended": {
"max_tokens": 16384,
"input_cost_per_token": 0.000000225,
"output_cost_per_token": 0.00000225,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct:nitro": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.00000059,
"output_cost_per_token": 0.00000079,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/openai/gpt-4o": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-4-vision-preview": {
"max_tokens": 130000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"input_cost_per_image": 0.01445,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-3.5-turbo": { "openrouter/openai/gpt-3.5-turbo": {
"max_tokens": 4095, "max_tokens": 4095,
"input_cost_per_token": 0.0000015, "input_cost_per_token": 0.0000015,
@ -1621,14 +1846,14 @@
"tool_use_system_prompt_tokens": 395 "tool_use_system_prompt_tokens": 395
}, },
"openrouter/google/palm-2-chat-bison": { "openrouter/google/palm-2-chat-bison": {
"max_tokens": 8000, "max_tokens": 25804,
"input_cost_per_token": 0.0000005, "input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005, "output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat" "mode": "chat"
}, },
"openrouter/google/palm-2-codechat-bison": { "openrouter/google/palm-2-codechat-bison": {
"max_tokens": 8000, "max_tokens": 20070,
"input_cost_per_token": 0.0000005, "input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005, "output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
@ -1711,13 +1936,6 @@
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat" "mode": "chat"
}, },
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008,
"litellm_provider": "openrouter",
"mode": "chat"
},
"j2-ultra": { "j2-ultra": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,
@ -2522,6 +2740,24 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"cohere.command-r-plus-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000030,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"cohere.command-r-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"cohere.embed-english-v3": { "cohere.embed-english-v3": {
"max_tokens": 512, "max_tokens": 512,
"max_input_tokens": 512, "max_input_tokens": 512,
@ -2749,6 +2985,24 @@
"litellm_provider": "ollama", "litellm_provider": "ollama",
"mode": "completion" "mode": "completion"
}, },
"ollama/llama3": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/llama3:70b": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mistral": { "ollama/mistral": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,
@ -2758,6 +3012,42 @@
"litellm_provider": "ollama", "litellm_provider": "ollama",
"mode": "completion" "mode": "completion"
}, },
"ollama/mistral-7B-Instruct-v0.1": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mistral-7B-Instruct-v0.2": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 32768,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mixtral-8x7B-Instruct-v0.1": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 32768,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mixtral-8x22B-Instruct-v0.1": {
"max_tokens": 65536,
"max_input_tokens": 65536,
"max_output_tokens": 65536,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/codellama": { "ollama/codellama": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 4096, "max_input_tokens": 4096,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/a1602eb39f799143.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}(); !function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/f04e46b02318b660.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-5b257e1ab47d4b4a.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-dafd44dfa2da140c.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e49705773ae41779.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-5b257e1ab47d4b4a.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/a1602eb39f799143.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[25539,[\"936\",\"static/chunks/2f6dbc85-17d29013b8ff3da5.js\",\"566\",\"static/chunks/566-ccd699ab19124658.js\",\"931\",\"static/chunks/app/page-c804e862b63be987.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/a1602eb39f799143.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"K8KXTbmuI2ArWjjdMi2iq\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html> <!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[4858,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-c35c14c9afd091ec.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"2ASoJGxS-D4w-vat00xMy\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@ -1,7 +1,7 @@
2:I[77831,[],""] 2:I[77831,[],""]
3:I[25539,["936","static/chunks/2f6dbc85-17d29013b8ff3da5.js","566","static/chunks/566-ccd699ab19124658.js","931","static/chunks/app/page-c804e862b63be987.js"],""] 3:I[4858,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-c35c14c9afd091ec.js"],""]
4:I[5613,[],""] 4:I[5613,[],""]
5:I[31778,[],""] 5:I[31778,[],""]
0:["K8KXTbmuI2ArWjjdMi2iq",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/a1602eb39f799143.css","precedence":"next","crossOrigin":""}]],"$L6"]]]] 0:["2ASoJGxS-D4w-vat00xMy",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]] 6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null 1:null

View file

@ -16,4 +16,3 @@ router_settings:
general_settings: general_settings:
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys

View file

@ -52,8 +52,18 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
class LiteLLMRoutes(enum.Enum): class LiteLLMRoutes(enum.Enum):
openai_route_names: List = [
"chat_completion",
"completion",
"embeddings",
"image_generation",
"audio_transcriptions",
"moderations",
"model_list", # OpenAI /v1/models route
]
openai_routes: List = [ openai_routes: List = [
# chat completions # chat completions
"/engines/{model}/chat/completions",
"/openai/deployments/{model}/chat/completions", "/openai/deployments/{model}/chat/completions",
"/chat/completions", "/chat/completions",
"/v1/chat/completions", "/v1/chat/completions",
@ -79,15 +89,23 @@ class LiteLLMRoutes(enum.Enum):
"/v1/models", "/v1/models",
] ]
llm_utils_routes: List = ["utils/token_counter"]
info_routes: List = [ info_routes: List = [
"/key/info", "/key/info",
"/team/info", "/team/info",
"/team/list",
"/user/info", "/user/info",
"/model/info", "/model/info",
"/v2/model/info", "/v2/model/info",
"/v2/key/info", "/v2/key/info",
] ]
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes: List = [
"/global/spend/reset",
]
sso_only_routes: List = [ sso_only_routes: List = [
"/key/generate", "/key/generate",
"/key/update", "/key/update",
@ -110,6 +128,7 @@ class LiteLLMRoutes(enum.Enum):
"/team/new", "/team/new",
"/team/update", "/team/update",
"/team/delete", "/team/delete",
"/team/list",
"/team/info", "/team/info",
"/team/block", "/team/block",
"/team/unblock", "/team/unblock",
@ -182,15 +201,32 @@ class LiteLLM_JWTAuth(LiteLLMBase):
admin_jwt_scope: str = "litellm_proxy_admin" admin_jwt_scope: str = "litellm_proxy_admin"
admin_allowed_routes: List[ admin_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"] Literal[
] = ["management_routes"] "openai_routes",
team_jwt_scope: str = "litellm_team" "info_routes",
team_id_jwt_field: str = "client_id" "management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
]
] = [
"management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
"info_routes",
]
team_id_jwt_field: Optional[str] = None
team_allowed_routes: List[ team_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"] Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"] ] = ["openai_routes", "info_routes"]
team_id_default: Optional[str] = Field(
default=None,
description="If no team_id given, default permissions/spend-tracking to this team.s",
)
org_id_jwt_field: Optional[str] = None org_id_jwt_field: Optional[str] = None
user_id_jwt_field: Optional[str] = None user_id_jwt_field: Optional[str] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
end_user_id_jwt_field: Optional[str] = None end_user_id_jwt_field: Optional[str] = None
public_key_ttl: float = 600 public_key_ttl: float = 600
@ -678,6 +714,25 @@ class DynamoDBArgs(LiteLLMBase):
assume_role_aws_session_name: Optional[str] = None assume_role_aws_session_name: Optional[str] = None
class ConfigFieldUpdate(LiteLLMBase):
field_name: str
field_value: Any
config_type: Literal["general_settings"]
class ConfigFieldDelete(LiteLLMBase):
config_type: Literal["general_settings"]
field_name: str
class ConfigList(LiteLLMBase):
field_name: str
field_type: str
field_description: str
field_value: Any
stored_in_db: Optional[bool]
class ConfigGeneralSettings(LiteLLMBase): class ConfigGeneralSettings(LiteLLMBase):
""" """
Documents all the fields supported by `general_settings` in config.yaml Documents all the fields supported by `general_settings` in config.yaml
@ -725,7 +780,11 @@ class ConfigGeneralSettings(LiteLLMBase):
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth", description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
) )
max_parallel_requests: Optional[int] = Field( max_parallel_requests: Optional[int] = Field(
None, description="maximum parallel requests for each api key" None,
description="maximum parallel requests for each api key",
)
global_max_parallel_requests: Optional[int] = Field(
None, description="global max parallel requests to allow for a proxy instance."
) )
infer_model_from_keys: Optional[bool] = Field( infer_model_from_keys: Optional[bool] = Field(
None, None,
@ -954,3 +1013,16 @@ class LiteLLM_ErrorLogs(LiteLLMBase):
class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase): class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase):
response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None
class TokenCountRequest(LiteLLMBase):
model: str
prompt: Optional[str] = None
messages: Optional[List[dict]] = None
class TokenCountResponse(LiteLLMBase):
total_tokens: int
request_model: str
model_used: str
tokenizer_type: str

View file

@ -26,7 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
def common_checks( def common_checks(
request_body: dict, request_body: dict,
team_object: LiteLLM_TeamTable, team_object: Optional[LiteLLM_TeamTable],
user_object: Optional[LiteLLM_UserTable], user_object: Optional[LiteLLM_UserTable],
end_user_object: Optional[LiteLLM_EndUserTable], end_user_object: Optional[LiteLLM_EndUserTable],
global_proxy_spend: Optional[float], global_proxy_spend: Optional[float],
@ -45,13 +45,14 @@ def common_checks(
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
""" """
_model = request_body.get("model", None) _model = request_body.get("model", None)
if team_object.blocked == True: if team_object is not None and team_object.blocked == True:
raise Exception( raise Exception(
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin." f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
) )
# 2. If user can call model # 2. If user can call model
if ( if (
_model is not None _model is not None
and team_object is not None
and len(team_object.models) > 0 and len(team_object.models) > 0
and _model not in team_object.models and _model not in team_object.models
): ):
@ -65,7 +66,8 @@ def common_checks(
) )
# 3. If team is in budget # 3. If team is in budget
if ( if (
team_object.max_budget is not None team_object is not None
and team_object.max_budget is not None
and team_object.spend is not None and team_object.spend is not None
and team_object.spend > team_object.max_budget and team_object.spend > team_object.max_budget
): ):
@ -239,6 +241,7 @@ async def get_user_object(
user_id: str, user_id: str,
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
user_id_upsert: bool,
) -> Optional[LiteLLM_UserTable]: ) -> Optional[LiteLLM_UserTable]:
""" """
- Check if user id in proxy User Table - Check if user id in proxy User Table
@ -252,7 +255,7 @@ async def get_user_object(
return None return None
# check if in cache # check if in cache
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id) cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None: if cached_user_obj is not None:
if isinstance(cached_user_obj, dict): if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj) return LiteLLM_UserTable(**cached_user_obj)
@ -260,16 +263,27 @@ async def get_user_object(
return cached_user_obj return cached_user_obj
# else, check db # else, check db
try: try:
response = await prisma_client.db.litellm_usertable.find_unique( response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id} where={"user_id": user_id}
) )
if response is None: if response is None:
if user_id_upsert:
response = await prisma_client.db.litellm_usertable.create(
data={"user_id": user_id}
)
else:
raise Exception raise Exception
return LiteLLM_UserTable(**response.dict()) _response = LiteLLM_UserTable(**dict(response))
except Exception as e: # if end-user not in db
raise Exception( # save the user object to cache
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
return _response
except Exception as e: # if user not in db
raise ValueError(
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call." f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
) )
@ -290,7 +304,7 @@ async def get_team_object(
) )
# check if in cache # check if in cache
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id) cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
if cached_team_obj is not None: if cached_team_obj is not None:
if isinstance(cached_team_obj, dict): if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj) return LiteLLM_TeamTable(**cached_team_obj)
@ -305,7 +319,11 @@ async def get_team_object(
if response is None: if response is None:
raise Exception raise Exception
return LiteLLM_TeamTable(**response.dict()) _response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
return _response
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call." f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."

View file

@ -55,12 +55,9 @@ class JWTHandler:
return True return True
return False return False
def is_team(self, scopes: list) -> bool: def get_end_user_id(
if self.litellm_jwtauth.team_jwt_scope in scopes: self, token: dict, default_value: Optional[str]
return True ) -> Optional[str]:
return False
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
try: try:
if self.litellm_jwtauth.end_user_id_jwt_field is not None: if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field] user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
@ -70,13 +67,36 @@ class JWTHandler:
user_id = default_value user_id = default_value
return user_id return user_id
def is_required_team_id(self) -> bool:
"""
Returns:
- True: if 'team_id_jwt_field' is set
- False: if not
"""
if self.litellm_jwtauth.team_id_jwt_field is None:
return False
return True
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
if self.litellm_jwtauth.team_id_jwt_field is not None:
team_id = token[self.litellm_jwtauth.team_id_jwt_field] team_id = token[self.litellm_jwtauth.team_id_jwt_field]
elif self.litellm_jwtauth.team_id_default is not None:
team_id = self.litellm_jwtauth.team_id_default
else:
team_id = None
except KeyError: except KeyError:
team_id = default_value team_id = default_value
return team_id return team_id
def is_upsert_user_id(self) -> bool:
"""
Returns:
- True: if 'user_id_upsert' is set
- False: if not
"""
return self.litellm_jwtauth.user_id_upsert
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
if self.litellm_jwtauth.user_id_jwt_field is not None: if self.litellm_jwtauth.user_id_jwt_field is not None:
@ -207,12 +227,14 @@ class JWTHandler:
raise Exception(f"Validation fails: {str(e)}") raise Exception(f"Validation fails: {str(e)}")
elif public_key is not None and isinstance(public_key, str): elif public_key is not None and isinstance(public_key, str):
try: try:
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend()) cert = x509.load_pem_x509_certificate(
public_key.encode(), default_backend()
)
# Extract public key # Extract public key
key = cert.public_key().public_bytes( key = cert.public_key().public_bytes(
serialization.Encoding.PEM, serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo serialization.PublicFormat.SubjectPublicKeyInfo,
) )
# decode the token using the public key # decode the token using the public key
@ -221,7 +243,7 @@ class JWTHandler:
key, key,
algorithms=algorithms, algorithms=algorithms,
audience=audience, audience=audience,
options=decode_options options=decode_options,
) )
return payload return payload

View file

@ -1,10 +1,7 @@
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
from fastapi import Request from fastapi import Request
from dotenv import load_dotenv
import os import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try: try:

View file

@ -0,0 +1,147 @@
from litellm.integrations.custom_logger import CustomLogger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
import litellm, traceback, sys, uuid
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from typing import Optional
class _PROXY_AzureContentSafety(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self, endpoint, api_key, thresholds=None):
try:
from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
from azure.ai.contentsafety.models import (
TextCategory,
AnalyzeTextOptions,
AnalyzeTextOutputType,
)
from azure.core.exceptions import HttpResponseError
except Exception as e:
raise Exception(
f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
)
self.endpoint = endpoint
self.api_key = api_key
self.text_category = TextCategory
self.analyze_text_options = AnalyzeTextOptions
self.analyze_text_output_type = AnalyzeTextOutputType
self.azure_http_error = HttpResponseError
self.thresholds = self._configure_thresholds(thresholds)
self.client = ContentSafetyClient(
self.endpoint, AzureKeyCredential(self.api_key)
)
def _configure_thresholds(self, thresholds=None):
default_thresholds = {
self.text_category.HATE: 4,
self.text_category.SELF_HARM: 4,
self.text_category.SEXUAL: 4,
self.text_category.VIOLENCE: 4,
}
if thresholds is None:
return default_thresholds
for key, default in default_thresholds.items():
if key not in thresholds:
thresholds[key] = default
return thresholds
def _compute_result(self, response):
result = {}
category_severity = {
item.category: item.severity for item in response.categories_analysis
}
for category in self.text_category:
severity = category_severity.get(category)
if severity is not None:
result[category] = {
"filtered": severity >= self.thresholds[category],
"severity": severity,
}
return result
async def test_violation(self, content: str, source: Optional[str] = None):
verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
# Construct a request
request = self.analyze_text_options(
text=content,
output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
)
# Analyze text
try:
response = await self.client.analyze_text(request)
except self.azure_http_error as e:
verbose_proxy_logger.debug(
"Error in Azure Content-Safety: %s", traceback.format_exc()
)
traceback.print_exc()
raise
result = self._compute_result(response)
verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
for key, value in result.items():
if value["filtered"]:
raise HTTPException(
status_code=400,
detail={
"error": "Violated content safety policy",
"source": source,
"category": key,
"severity": value["severity"],
},
)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
try:
if call_type == "completion" and "messages" in data:
for m in data["messages"]:
if "content" in m and isinstance(m["content"], str):
await self.test_violation(content=m["content"], source="input")
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response,
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices[0], litellm.utils.Choices
):
await self.test_violation(
content=response.choices[0].message.content, source="output"
)
# async def async_post_call_streaming_hook(
# self,
# user_api_key_dict: UserAPIKeyAuth,
# response: str,
# ):
# verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
# await self.test_violation(content=response, source="output")

View file

@ -79,6 +79,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
max_parallel_requests = user_api_key_dict.max_parallel_requests max_parallel_requests = user_api_key_dict.max_parallel_requests
if max_parallel_requests is None: if max_parallel_requests is None:
max_parallel_requests = sys.maxsize max_parallel_requests = sys.maxsize
global_max_parallel_requests = data.get("metadata", {}).get(
"global_max_parallel_requests", None
)
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
if tpm_limit is None: if tpm_limit is None:
tpm_limit = sys.maxsize tpm_limit = sys.maxsize
@ -91,6 +94,24 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values # Setup values
# ------------ # ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = await cache.async_get_cache(
key=_key, local_only=True
)
# check if below limit
if current_global_requests is None:
current_global_requests = 1
# if above -> raise error
if current_global_requests >= global_max_parallel_requests:
raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
# if below -> increment
else:
await cache.async_increment_cache(key=_key, value=1, local_only=True)
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H") current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M") current_minute = datetime.now().strftime("%M")
@ -207,6 +228,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING") self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"] user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None "user_api_key_user_id", None
@ -222,6 +246,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values # Setup values
# ------------ # ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
# decrement
await self.user_api_key_cache.async_increment_cache(
key=_key, value=-1, local_only=True
)
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H") current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M") current_minute = datetime.now().strftime("%M")
@ -336,6 +368,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try: try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook") self.print_verbose(f"Inside Max Parallel Request Failure Hook")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = ( user_api_key = (
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None) kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
) )
@ -347,17 +382,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
return return
## decrement call count if call failed ## decrement call count if call failed
if ( if "Max parallel request limit reached" in str(kwargs["exception"]):
hasattr(kwargs["exception"], "status_code")
and kwargs["exception"].status_code == 429
and "Max parallel request limit reached" in str(kwargs["exception"])
):
pass # ignore failed calls due to max limit being reached pass # ignore failed calls due to max limit being reached
else: else:
# ------------ # ------------
# Setup values # Setup values
# ------------ # ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = (
await self.user_api_key_cache.async_get_cache(
key=_key, local_only=True
)
)
# decrement
await self.user_api_key_cache.async_increment_cache(
key=_key, value=-1, local_only=True
)
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H") current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M") current_minute = datetime.now().strftime("%M")

View file

@ -11,7 +11,9 @@ sys.path.append(os.getcwd())
config_filename = "litellm.secrets" config_filename = "litellm.secrets"
load_dotenv() litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
if litellm_mode == "DEV":
load_dotenv()
from importlib import resources from importlib import resources
import shutil import shutil

View file

@ -4,11 +4,20 @@ model_list:
model: openai/fake model: openai/fake
api_key: fake-key api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/ api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: llama3
litellm_params:
model: groq/llama3-8b-8192
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
- model_name: "*" - model_name: "*"
litellm_params: litellm_params:
model: openai/* model: openai/*
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
- model_name: my-triton-model
litellm_params:
model: triton/any"
api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
general_settings: general_settings:
store_model_in_db: true store_model_in_db: true
@ -17,4 +26,10 @@ general_settings:
litellm_settings: litellm_settings:
success_callback: ["langfuse"] success_callback: ["langfuse"]
_langfuse_default_tags: ["user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"] failure_callback: ["langfuse"]
default_team_settings:
- team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
success_callback: ["langfuse"]
failure_callback: ["langfuse"]
langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"

File diff suppressed because it is too large Load diff

View file

@ -140,6 +140,8 @@ class ProxyLogging:
self.slack_alerting_instance.response_taking_too_long_callback self.slack_alerting_instance.response_taking_too_long_callback
) )
for callback in litellm.callbacks: for callback in litellm.callbacks:
if isinstance(callback, str):
callback = litellm.utils._init_custom_logger_compatible_class(callback)
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
litellm.input_callback.append(callback) litellm.input_callback.append(callback)
if callback not in litellm.success_callback: if callback not in litellm.success_callback:
@ -252,8 +254,8 @@ class ProxyLogging:
""" """
Runs the CustomLogger's async_moderation_hook() Runs the CustomLogger's async_moderation_hook()
""" """
for callback in litellm.callbacks:
new_data = copy.deepcopy(data) new_data = copy.deepcopy(data)
for callback in litellm.callbacks:
try: try:
if isinstance(callback, CustomLogger): if isinstance(callback, CustomLogger):
await callback.async_moderation_hook( await callback.async_moderation_hook(
@ -418,9 +420,14 @@ class ProxyLogging:
Related issue - https://github.com/BerriAI/litellm/issues/3395 Related issue - https://github.com/BerriAI/litellm/issues/3395
""" """
litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
exception_str = str(original_exception)
if litellm_debug_info is not None:
exception_str += litellm_debug_info
asyncio.create_task( asyncio.create_task(
self.alerting_handler( self.alerting_handler(
message=f"LLM API call failed: {str(original_exception)}", message=f"LLM API call failed: {exception_str}",
level="High", level="High",
alert_type="llm_exceptions", alert_type="llm_exceptions",
request_data=request_data, request_data=request_data,
@ -1787,7 +1794,9 @@ def hash_token(token: str):
return hashed_token return hashed_token
def get_logging_payload(kwargs, response_obj, start_time, end_time): def get_logging_payload(
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
):
from litellm.proxy._types import LiteLLM_SpendLogs from litellm.proxy._types import LiteLLM_SpendLogs
from pydantic import Json from pydantic import Json
import uuid import uuid
@ -1865,7 +1874,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
"prompt_tokens": usage.get("prompt_tokens", 0), "prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0), "completion_tokens": usage.get("completion_tokens", 0),
"request_tags": metadata.get("tags", []), "request_tags": metadata.get("tags", []),
"end_user": kwargs.get("user", ""), "end_user": end_user_id or "",
"api_base": litellm_params.get("api_base", ""), "api_base": litellm_params.get("api_base", ""),
} }
@ -2028,6 +2037,11 @@ async def update_spend(
raise e raise e
### UPDATE END-USER TABLE ### ### UPDATE END-USER TABLE ###
verbose_proxy_logger.debug(
"End-User Spend transactions: {}".format(
len(prisma_client.end_user_list_transactons.keys())
)
)
if len(prisma_client.end_user_list_transactons.keys()) > 0: if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time() start_time = time.time()
@ -2043,13 +2057,18 @@ async def update_spend(
max_end_user_budget = None max_end_user_budget = None
if litellm.max_end_user_budget is not None: if litellm.max_end_user_budget is not None:
max_end_user_budget = litellm.max_end_user_budget max_end_user_budget = litellm.max_end_user_budget
new_user_obj = LiteLLM_EndUserTable( batcher.litellm_endusertable.upsert(
user_id=end_user_id, spend=response_cost, blocked=False
)
batcher.litellm_endusertable.update_many(
where={"user_id": end_user_id}, where={"user_id": end_user_id},
data={"spend": {"increment": response_cost}}, data={
"create": {
"user_id": end_user_id,
"spend": response_cost,
"blocked": False,
},
"update": {"spend": {"increment": response_cost}},
},
) )
prisma_client.end_user_list_transactons = ( prisma_client.end_user_list_transactons = (
{} {}
) # Clear the remaining transactions after processing all batches in the loop. ) # Clear the remaining transactions after processing all batches in the loop.

View file

@ -9,7 +9,8 @@
import copy, httpx import copy, httpx
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache from litellm.caching import RedisCache, InMemoryCache, DualCache
@ -46,8 +47,10 @@ from litellm.types.router import (
updateLiteLLMParams, updateLiteLLMParams,
RetryPolicy, RetryPolicy,
AlertingConfig, AlertingConfig,
DeploymentTypedDict,
) )
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
class Router: class Router:
@ -60,7 +63,7 @@ class Router:
def __init__( def __init__(
self, self,
model_list: Optional[list] = None, model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
## CACHING ## ## CACHING ##
redis_url: Optional[str] = None, redis_url: Optional[str] = None,
redis_host: Optional[str] = None, redis_host: Optional[str] = None,
@ -81,6 +84,9 @@ class Router:
default_max_parallel_requests: Optional[int] = None, default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False, set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO", debug_level: Literal["DEBUG", "INFO"] = "INFO",
default_fallbacks: Optional[
List[str]
] = None, # generic fallbacks, works across all deployments
fallbacks: List = [], fallbacks: List = [],
context_window_fallbacks: List = [], context_window_fallbacks: List = [],
model_group_alias: Optional[dict] = {}, model_group_alias: Optional[dict] = {},
@ -256,7 +262,22 @@ class Router:
self.retry_after = retry_after self.retry_after = retry_after
self.routing_strategy = routing_strategy self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
## SETTING FALLBACKS ##
### validate if it's set + in correct format
_fallbacks = fallbacks or litellm.fallbacks
self.validate_fallbacks(fallback_param=_fallbacks)
### set fallbacks
self.fallbacks = _fallbacks
if default_fallbacks is not None or litellm.default_fallbacks is not None:
_fallbacks = default_fallbacks or litellm.default_fallbacks
if self.fallbacks is not None:
self.fallbacks.append({"*": _fallbacks})
else:
self.fallbacks = [{"*": _fallbacks}]
self.context_window_fallbacks = ( self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks context_window_fallbacks or litellm.context_window_fallbacks
) )
@ -324,6 +345,21 @@ class Router:
if self.alerting_config is not None: if self.alerting_config is not None:
self._initialize_alerting() self._initialize_alerting()
def validate_fallbacks(self, fallback_param: Optional[List]):
if fallback_param is None:
return
if len(fallback_param) > 0: # if set
## for dictionary in list, check if only 1 key in dict
for _dict in fallback_param:
assert isinstance(_dict, dict), "Item={}, not a dictionary".format(
_dict
)
assert (
len(_dict.keys()) == 1
), "Only 1 key allows in dictionary. You set={} for dict={}".format(
len(_dict.keys()), _dict
)
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict): def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy": if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler( self.leastbusy_logger = LeastBusyLoggingHandler(
@ -468,12 +504,30 @@ class Router:
) )
raise e raise e
# fmt: off
@overload
async def acompletion( async def acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]: ) -> CustomStreamWrapper:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
# fmt: on
# The actual implementation of the function
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
):
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
@ -605,6 +659,33 @@ class Router:
self.fail_calls[model_name] += 1 self.fail_calls[model_name] += 1
raise e raise e
async def abatch_completion(
self, models: List[str], messages: List[Dict[str, str]], **kwargs
):
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
"""
try:
return await self.acompletion(model=model, messages=messages, **kwargs)
except Exception as e:
return e
_tasks = []
for model in models:
# add each task but if the task fails
_tasks.append(
_async_completion_no_exceptions(
model=model, messages=messages, **kwargs
)
)
response = await asyncio.gather(*_tasks)
return response
def image_generation(self, prompt: str, model: str, **kwargs): def image_generation(self, prompt: str, model: str, **kwargs):
try: try:
kwargs["model"] = model kwargs["model"] = model
@ -1385,7 +1466,7 @@ class Router:
verbose_router_logger.debug(f"Trying to fallback b/w models") verbose_router_logger.debug(f"Trying to fallback b/w models")
if ( if (
hasattr(e, "status_code") hasattr(e, "status_code")
and e.status_code == 400 and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError) and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request ): # don't retry a malformed request
raise e raise e
@ -1416,18 +1497,29 @@ class Router:
response = await self.async_function_with_retries( response = await self.async_function_with_retries(
*args, **kwargs *args, **kwargs
) )
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response return response
except Exception as e: except Exception as e:
pass pass
elif fallbacks is not None: elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
for item in fallbacks: generic_fallback_idx: Optional[int] = None
key_list = list(item.keys()) ## check for specific model group-specific fallbacks
if len(key_list) == 0: for idx, item in enumerate(fallbacks):
continue if list(item.keys())[0] == model_group:
if key_list[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None: if fallback_model_group is None:
verbose_router_logger.info( verbose_router_logger.info(
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
@ -1450,6 +1542,9 @@ class Router:
response = await self.async_function_with_fallbacks( response = await self.async_function_with_fallbacks(
*args, **kwargs *args, **kwargs
) )
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response return response
except Exception as e: except Exception as e:
raise e raise e
@ -1479,22 +1574,30 @@ class Router:
return response return response
except Exception as e: except Exception as e:
original_exception = e original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error """
if ( Retry Logic
isinstance(original_exception, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
) or (
isinstance(original_exception, openai.RateLimitError)
and fallbacks is not None
):
raise original_exception
### RETRY
_timeout = self._router_should_retry( """
_healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model") or "",
)
# raises an exception if this error should not be retries
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception, e=original_exception,
remaining_retries=num_retries, remaining_retries=num_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
# sleeps for the length of the timeout
await asyncio.sleep(_timeout) await asyncio.sleep(_timeout)
if ( if (
@ -1528,10 +1631,14 @@ class Router:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
_timeout = self._router_should_retry( _healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
)
_timeout = self._time_to_sleep_before_retry(
e=original_exception, e=original_exception,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
await asyncio.sleep(_timeout) await asyncio.sleep(_timeout)
try: try:
@ -1540,17 +1647,57 @@ class Router:
pass pass
raise original_exception raise original_exception
def should_retry_this_error(
self,
error: Exception,
healthy_deployments: Optional[List] = None,
context_window_fallbacks: Optional[List] = None,
):
"""
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
2. raise an exception for RateLimitError if
- there are no fallbacks
- there are no healthy deployments in the same model group
"""
_num_healthy_deployments = 0
if healthy_deployments is not None and isinstance(healthy_deployments, list):
_num_healthy_deployments = len(healthy_deployments)
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(error, litellm.ContextWindowExceededError)
and context_window_fallbacks is None
):
raise error
# Error we should only retry if there are other deployments
if isinstance(error, openai.RateLimitError) or isinstance(
error, openai.AuthenticationError
):
if _num_healthy_deployments <= 0:
raise error
return True
def function_with_fallbacks(self, *args, **kwargs): def function_with_fallbacks(self, *args, **kwargs):
""" """
Try calling the function_with_retries Try calling the function_with_retries
If it fails after num_retries, fall back to another model group If it fails after num_retries, fall back to another model group
""" """
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
model_group = kwargs.get("model") model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks) fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get( context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks "context_window_fallbacks", self.context_window_fallbacks
) )
try: try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
)
response = self.function_with_retries(*args, **kwargs) response = self.function_with_retries(*args, **kwargs)
return response return response
except Exception as e: except Exception as e:
@ -1559,7 +1706,7 @@ class Router:
try: try:
if ( if (
hasattr(e, "status_code") hasattr(e, "status_code")
and e.status_code == 400 and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError) and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request ): # don't retry a malformed request
raise e raise e
@ -1601,10 +1748,20 @@ class Router:
elif fallbacks is not None: elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
fallback_model_group = None fallback_model_group = None
for item in fallbacks: generic_fallback_idx: Optional[int] = None
## check for specific model group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group: if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None: if fallback_model_group is None:
raise original_exception raise original_exception
@ -1628,12 +1785,27 @@ class Router:
raise e raise e
raise original_exception raise original_exception
def _router_should_retry( def _time_to_sleep_before_retry(
self, e: Exception, remaining_retries: int, num_retries: int self,
e: Exception,
remaining_retries: int,
num_retries: int,
healthy_deployments: Optional[List] = None,
) -> Union[int, float]: ) -> Union[int, float]:
""" """
Calculate back-off, then retry Calculate back-off, then retry
It should instantly retry only when:
1. there are healthy deployments in the same model group
2. there are fallbacks for the completion call
""" """
if (
healthy_deployments is not None
and isinstance(healthy_deployments, list)
and len(healthy_deployments) > 0
):
return 0
if hasattr(e, "response") and hasattr(e.response, "headers"): if hasattr(e, "response") and hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
@ -1670,23 +1842,29 @@ class Router:
except Exception as e: except Exception as e:
original_exception = e original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
if ( _healthy_deployments = self._get_healthy_deployments(
isinstance(original_exception, litellm.ContextWindowExceededError) model=kwargs.get("model"),
and context_window_fallbacks is not None )
) or (
isinstance(original_exception, openai.RateLimitError) # raises an exception if this error should not be retries
and fallbacks is not None self.should_retry_this_error(
): error=e,
raise original_exception healthy_deployments=_healthy_deployments,
## LOGGING context_window_fallbacks=context_window_fallbacks,
if num_retries > 0: )
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
### RETRY # decides how long to sleep before retry
_timeout = self._router_should_retry( _timeout = self._time_to_sleep_before_retry(
e=original_exception, e=original_exception,
remaining_retries=num_retries, remaining_retries=num_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
time.sleep(_timeout) time.sleep(_timeout)
for current_attempt in range(num_retries): for current_attempt in range(num_retries):
verbose_router_logger.debug( verbose_router_logger.debug(
@ -1700,11 +1878,15 @@ class Router:
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
_healthy_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
_timeout = self._router_should_retry( _timeout = self._time_to_sleep_before_retry(
e=e, e=e,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments,
) )
time.sleep(_timeout) time.sleep(_timeout)
raise original_exception raise original_exception
@ -1804,6 +1986,45 @@ class Router:
key=rpm_key, value=request_count, local_only=True key=rpm_key, value=request_count, local_only=True
) # don't change existing ttl ) # don't change existing ttl
def _is_cooldown_required(self, exception_status: Union[str, int]):
"""
A function to determine if a cooldown is required based on the exception status.
Parameters:
exception_status (Union[str, int]): The status of the exception.
Returns:
bool: True if a cooldown is required, False otherwise.
"""
try:
if isinstance(exception_status, str):
exception_status = int(exception_status)
if exception_status >= 400 and exception_status < 500:
if exception_status == 429:
# Cool down 429 Rate Limit Errors
return True
elif exception_status == 401:
# Cool down 401 Auth Errors
return True
elif exception_status == 408:
return True
else:
# Do NOT cool down all other 4XX Errors
return False
else:
# should cool down for all other errors
return True
except:
# Catch all - if any exceptions default to cooling down
return True
def _set_cooldown_deployments( def _set_cooldown_deployments(
self, exception_status: Union[str, int], deployment: Optional[str] = None self, exception_status: Union[str, int], deployment: Optional[str] = None
): ):
@ -1817,6 +2038,9 @@ class Router:
if deployment is None: if deployment is None:
return return
if self._is_cooldown_required(exception_status=exception_status) == False:
return
dt = get_utc_datetime() dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M") current_minute = dt.strftime("%H-%M")
# get current fails for deployment # get current fails for deployment
@ -1907,6 +2131,47 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models return cooldown_models
def _get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = self._get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
async def _async_get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = await self._async_get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
def routing_strategy_pre_call_checks(self, deployment: dict): def routing_strategy_pre_call_checks(self, deployment: dict):
""" """
Mimics 'async_routing_strategy_pre_call_checks' Mimics 'async_routing_strategy_pre_call_checks'
@ -2115,6 +2380,10 @@ class Router:
raise ValueError( raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}" f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
) )
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None: if api_version is None:
api_version = "2023-07-01-preview" api_version = "2023-07-01-preview"
@ -2126,6 +2395,7 @@ class Router:
cache_key = f"{model_id}_async_client" cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( _client = openai.AsyncAzureOpenAI(
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout,
@ -2150,6 +2420,7 @@ class Router:
cache_key = f"{model_id}_client" cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore _client = openai.AzureOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout,
@ -2174,6 +2445,7 @@ class Router:
cache_key = f"{model_id}_stream_async_client" cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore _client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=stream_timeout, timeout=stream_timeout,
@ -2198,6 +2470,7 @@ class Router:
cache_key = f"{model_id}_stream_client" cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore _client = openai.AzureOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=stream_timeout, timeout=stream_timeout,
@ -2230,6 +2503,7 @@ class Router:
"api_key": api_key, "api_key": api_key,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"api_version": api_version, "api_version": api_version,
"azure_ad_token": azure_ad_token,
} }
from litellm.llms.azure import select_azure_base_url_or_endpoint from litellm.llms.azure import select_azure_base_url_or_endpoint
@ -2329,7 +2603,7 @@ class Router:
) # cache for 1 hr ) # cache for 1 hr
else: else:
_api_key = api_key _api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str): if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key # only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15 _api_key = _api_key[:8] + "*" * 15
@ -2557,16 +2831,25 @@ class Router:
# init OpenAI, Azure clients # init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True)) self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model) # set region (if azure model) ## PREVIEW FEATURE ##
if litellm.enable_preview_features == True:
print("Auto inferring region") # noqa
"""
Hiding behind a feature flag
When there is a large amount of LLM deployments this makes startup times blow up
"""
try: try:
if "azure" in deployment.litellm_params.model: if (
"azure" in deployment.litellm_params.model
and deployment.litellm_params.region_name is None
):
region = litellm.utils.get_model_region( region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None litellm_params=deployment.litellm_params, mode=None
) )
deployment.litellm_params.region_name = region deployment.litellm_params.region_name = region
except Exception as e: except Exception as e:
verbose_router_logger.error( verbose_router_logger.debug(
"Unable to get the region for azure model - {}, {}".format( "Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e) deployment.litellm_params.model, str(e)
) )
@ -2600,7 +2883,7 @@ class Router:
self.model_names.append(deployment.model_name) self.model_names.append(deployment.model_name)
return deployment return deployment
def upsert_deployment(self, deployment: Deployment) -> Deployment: def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]:
""" """
Add or update deployment Add or update deployment
Parameters: Parameters:
@ -2610,8 +2893,17 @@ class Router:
- The added/updated deployment - The added/updated deployment
""" """
# check if deployment already exists # check if deployment already exists
_deployment_model_id = deployment.model_info.id or ""
_deployment_on_router: Optional[Deployment] = self.get_deployment(
model_id=_deployment_model_id
)
if _deployment_on_router is not None:
# deployment with this model_id exists on the router
if deployment.litellm_params == _deployment_on_router.litellm_params:
# No need to update
return None
if deployment.model_info.id in self.get_model_ids(): # if there is a new litellm param -> then update the deployment
# remove the previous deployment # remove the previous deployment
removal_idx: Optional[int] = None removal_idx: Optional[int] = None
for idx, model in enumerate(self.model_list): for idx, model in enumerate(self.model_list):
@ -2620,16 +2912,9 @@ class Router:
if removal_idx is not None: if removal_idx is not None:
self.model_list.pop(removal_idx) self.model_list.pop(removal_idx)
else:
# add to model list # if the model_id is not in router
_deployment = deployment.to_json(exclude_none=True) self.add_deployment(deployment=deployment)
self.model_list.append(_deployment)
# initialize client
self._add_deployment(deployment=deployment)
# add to model names
self.model_names.append(deployment.model_name)
return deployment return deployment
def delete_deployment(self, id: str) -> Optional[Deployment]: def delete_deployment(self, id: str) -> Optional[Deployment]:
@ -2942,7 +3227,7 @@ class Router:
): ):
# check if in allowed_model_region # check if in allowed_model_region
if ( if (
_is_region_eu(model_region=_litellm_params["region_name"]) _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
== False == False
): ):
invalid_model_indices.append(idx) invalid_model_indices.append(idx)
@ -2966,7 +3251,7 @@ class Router:
if _rate_limit_error == True: # allow generic fallback logic to take place if _rate_limit_error == True: # allow generic fallback logic to take place
raise ValueError( raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}" f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
) )
elif _context_window_error == True: elif _context_window_error == True:
raise litellm.ContextWindowExceededError( raise litellm.ContextWindowExceededError(
@ -2990,11 +3275,15 @@ class Router:
messages: Optional[List[Dict[str, str]]] = None, messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None, input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False, specific_deployment: Optional[bool] = False,
): ) -> Tuple[str, Union[list, dict]]:
""" """
Common checks for 'get_available_deployment' across sync + async call. Common checks for 'get_available_deployment' across sync + async call.
If 'healthy_deployments' returned is None, this means the user chose a specific deployment If 'healthy_deployments' returned is None, this means the user chose a specific deployment
Returns
- Dict, if specific model chosen
- List, if multiple models chosen
""" """
# check if aliases set on litellm model alias map # check if aliases set on litellm model alias map
if specific_deployment == True: if specific_deployment == True:
@ -3004,7 +3293,7 @@ class Router:
if deployment_model == model: if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2 # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specificed deployment name # return the first deployment where the `model` matches the specificed deployment name
return deployment, None return deployment_model, deployment
raise ValueError( raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}" f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
) )
@ -3020,7 +3309,7 @@ class Router:
self.default_deployment self.default_deployment
) # self.default_deployment ) # self.default_deployment
updated_deployment["litellm_params"]["model"] = model updated_deployment["litellm_params"]["model"] = model
return updated_deployment, None return model, updated_deployment
## get healthy deployments ## get healthy deployments
### get all deployments ### get all deployments
@ -3034,7 +3323,9 @@ class Router:
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}") litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
raise ValueError(f"No healthy deployment available, passed model={model}. ") raise ValueError(
f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
)
if litellm.model_alias_map and model in litellm.model_alias_map: if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[ model = litellm.model_alias_map[
model model
@ -3073,10 +3364,10 @@ class Router:
messages=messages, messages=messages,
input=input, input=input,
specific_deployment=specific_deployment, specific_deployment=specific_deployment,
) ) # type: ignore
if healthy_deployments is None: if isinstance(healthy_deployments, dict):
return model return healthy_deployments
# filter out the deployments currently cooling down # filter out the deployments currently cooling down
deployments_to_remove = [] deployments_to_remove = []
@ -3095,13 +3386,12 @@ class Router:
healthy_deployments.remove(deployment) healthy_deployments.remove(deployment)
# filter pre-call checks # filter pre-call checks
if self.enable_pre_call_checks and messages is not None:
_allowed_model_region = ( _allowed_model_region = (
request_kwargs.get("allowed_model_region") request_kwargs.get("allowed_model_region")
if request_kwargs is not None if request_kwargs is not None
else None else None
) )
if self.enable_pre_call_checks and messages is not None:
if _allowed_model_region == "eu": if _allowed_model_region == "eu":
healthy_deployments = self._pre_call_checks( healthy_deployments = self._pre_call_checks(
model=model, model=model,
@ -3122,8 +3412,10 @@ class Router:
) )
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError( raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}" f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
) )
if ( if (
@ -3132,7 +3424,7 @@ class Router:
): ):
deployment = await self.lowesttpm_logger_v2.async_get_available_deployments( deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )
@ -3142,7 +3434,7 @@ class Router:
): ):
deployment = await self.lowestcost_logger.async_get_available_deployments( deployment = await self.lowestcost_logger.async_get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )
@ -3191,7 +3483,7 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"
) )
raise ValueError( raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}" f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
) )
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@ -3220,8 +3512,8 @@ class Router:
specific_deployment=specific_deployment, specific_deployment=specific_deployment,
) )
if healthy_deployments is None: if isinstance(healthy_deployments, dict):
return model return healthy_deployments
# filter out the deployments currently cooling down # filter out the deployments currently cooling down
deployments_to_remove = [] deployments_to_remove = []
@ -3245,7 +3537,7 @@ class Router:
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None: if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployment = self.leastbusy_logger.get_available_deployments( deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments model_group=model, healthy_deployments=healthy_deployments # type: ignore
) )
elif self.routing_strategy == "simple-shuffle": elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
@ -3293,7 +3585,7 @@ class Router:
): ):
deployment = self.lowestlatency_logger.get_available_deployments( deployment = self.lowestlatency_logger.get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
request_kwargs=request_kwargs, request_kwargs=request_kwargs,
) )
elif ( elif (
@ -3302,7 +3594,7 @@ class Router:
): ):
deployment = self.lowesttpm_logger.get_available_deployments( deployment = self.lowesttpm_logger.get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )
@ -3312,7 +3604,7 @@ class Router:
): ):
deployment = self.lowesttpm_logger_v2.get_available_deployments( deployment = self.lowesttpm_logger_v2.get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )
@ -3321,7 +3613,7 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"
) )
raise ValueError( raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}" f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
) )
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@ -3483,7 +3775,7 @@ class Router:
) )
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_alert( proxy_logging_obj.slack_alerting_instance.send_alert(
message=f"Router: Cooling down deployment: {_api_base}, for {self.cooldown_time} seconds. Got exception: {str(exception_status)}", message=f"Router: Cooling down Deployment:\nModel Name: {_model_name}\nAPI Base: {_api_base}\n{self.cooldown_time} seconds. Got exception: {str(exception_status)}. Change 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
alert_type="cooldown_deployment", alert_type="cooldown_deployment",
level="Low", level="Low",
) )

Some files were not shown because too many files have changed in this diff Show more