mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge branch 'main' into litellm_end_user_obj
This commit is contained in:
commit
0a775821db
178 changed files with 11955 additions and 2346 deletions
|
@ -58,6 +58,8 @@ jobs:
|
|||
pip install python-multipart
|
||||
pip install google-cloud-aiplatform
|
||||
pip install prometheus-client==0.20.0
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "diskcache==5.6.1"
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
|
@ -198,6 +200,7 @@ jobs:
|
|||
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||
-e AUTO_INFER_REGION=True \
|
||||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
|
||||
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
|
||||
|
|
10
.git-blame-ignore-revs
Normal file
10
.git-blame-ignore-revs
Normal file
|
@ -0,0 +1,10 @@
|
|||
# Add the commit hash of any commit you want to ignore in `git blame` here.
|
||||
# One commit hash per line.
|
||||
#
|
||||
# The GitHub Blame UI will use this file automatically!
|
||||
#
|
||||
# Run this command to always ignore formatting commits in `git blame`
|
||||
# git config blame.ignoreRevsFile .git-blame-ignore-revs
|
||||
|
||||
# Update pydantic code to fix warnings (GH-3600)
|
||||
876840e9957bc7e9f7d6a2b58c4d7c53dad16481
|
22
.github/pull_request_template.md
vendored
22
.github/pull_request_template.md
vendored
|
@ -1,6 +1,3 @@
|
|||
<!-- This is just examples. You can remove all items if you want. -->
|
||||
<!-- Please remove all comments. -->
|
||||
|
||||
## Title
|
||||
|
||||
<!-- e.g. "Implement user authentication feature" -->
|
||||
|
@ -18,7 +15,6 @@
|
|||
🐛 Bug Fix
|
||||
🧹 Refactoring
|
||||
📖 Documentation
|
||||
💻 Development Environment
|
||||
🚄 Infrastructure
|
||||
✅ Test
|
||||
|
||||
|
@ -26,22 +22,8 @@
|
|||
|
||||
<!-- List of changes -->
|
||||
|
||||
## Testing
|
||||
## [REQUIRED] Testing - Attach a screenshot of any new tests passing locall
|
||||
If UI changes, send a screenshot/GIF of working UI fixes
|
||||
|
||||
<!-- Test procedure -->
|
||||
|
||||
## Notes
|
||||
|
||||
<!-- Test results -->
|
||||
|
||||
<!-- Points to note for the reviewer, consultation content, concerns -->
|
||||
|
||||
## Pre-Submission Checklist (optional but appreciated):
|
||||
|
||||
- [ ] I have included relevant documentation updates (stored in /docs/my-website)
|
||||
|
||||
## OS Tests (optional but appreciated):
|
||||
|
||||
- [ ] Tested on Windows
|
||||
- [ ] Tested on MacOS
|
||||
- [ ] Tested on Linux
|
||||
|
|
187
cookbook/liteLLM_clarifai_Demo.ipynb
vendored
Normal file
187
cookbook/liteLLM_clarifai_Demo.ipynb
vendored
Normal file
|
@ -0,0 +1,187 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LiteLLM Clarifai \n",
|
||||
"This notebook walks you through on how to use liteLLM integration of Clarifai and call LLM model from clarifai with response in openAI output format."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Pre-Requisites"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#install necessary packages\n",
|
||||
"!pip install litellm\n",
|
||||
"!pip install clarifai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To obtain Clarifai Personal Access Token follow the steps mentioned in the [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Set Clarifai Credentials\n",
|
||||
"import os\n",
|
||||
"os.environ[\"CLARIFAI_API_KEY\"]= \"YOUR_CLARIFAI_PAT\" # Clarifai PAT"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Mistral-large"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import litellm\n",
|
||||
"\n",
|
||||
"litellm.set_verbose=False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Mistral large response : ModelResponse(id='chatcmpl-6eed494d-7ae2-4870-b9c2-6a64d50a6151', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\"In the grand tapestry of time, where tales unfold,\\nLies the chronicle of ages, a sight to behold.\\nA tale of empires rising, and kings of old,\\nOf civilizations lost, and stories untold.\\n\\nOnce upon a yesterday, in a time so vast,\\nHumans took their first steps, casting shadows in the past.\\nFrom the cradle of mankind, a journey they embarked,\\nThrough stone and bronze and iron, their skills they sharpened and marked.\\n\\nEgyptians built pyramids, reaching for the skies,\\nWhile Greeks sought wisdom, truth, in philosophies that lie.\\nRoman legions marched, their empire to expand,\\nAnd in the East, the Silk Road joined the world, hand in hand.\\n\\nThe Middle Ages came, with knights in shining armor,\\nFeudal lords and serfs, a time of both clamor and calm order.\\nThen Renaissance bloomed, like a flower in the sun,\\nA rebirth of art and science, a new age had begun.\\n\\nAcross the vast oceans, explorers sailed with courage bold,\\nDiscovering new lands, stories of adventure, untold.\\nIndustrial Revolution churned, progress in its wake,\\nMachines and factories, a whole new world to make.\\n\\nTwo World Wars raged, a testament to man's strife,\\nYet from the ashes rose hope, a renewed will for life.\\nInto the modern era, technology took flight,\\nConnecting every corner, bathed in digital light.\\n\\nHistory, a symphony, a melody of time,\\nA testament to human will, resilience so sublime.\\nIn every page, a lesson, in every tale, a guide,\\nFor understanding our past, shapes our future's tide.\", role='assistant'))], created=1713896412, model='https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=13, completion_tokens=338, total_tokens=351))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from litellm import completion\n",
|
||||
"\n",
|
||||
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
|
||||
"response=completion(\n",
|
||||
" model=\"clarifai/mistralai.completion.mistral-large\",\n",
|
||||
" messages=messages,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(f\"Mistral large response : {response}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Claude-2.1 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Claude-2.1 response : ModelResponse(id='chatcmpl-d126c919-4db4-4aa3-ac8f-7edea41e0b93', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\" Here's a poem I wrote about history:\\n\\nThe Tides of Time\\n\\nThe tides of time ebb and flow,\\nCarrying stories of long ago.\\nFigures and events come into light,\\nShaping the future with all their might.\\n\\nKingdoms rise, empires fall, \\nLeaving traces that echo down every hall.\\nRevolutions bring change with a fiery glow,\\nToppling structures from long ago.\\n\\nExplorers traverse each ocean and land,\\nSeeking treasures they don't understand.\\nWhile artists and writers try to make their mark,\\nHoping their works shine bright in the dark.\\n\\nThe cycle repeats again and again,\\nAs humanity struggles to learn from its pain.\\nThough the players may change on history's stage,\\nThe themes stay the same from age to age.\\n\\nWar and peace, life and death,\\nLove and strife with every breath.\\nThe tides of time continue their dance,\\nAs we join in, by luck or by chance.\\n\\nSo we study the past to light the way forward, \\nHeeding warnings from stories told and heard.\\nThe future unfolds from this unending flow -\\nWhere the tides of time ultimately go.\", role='assistant'))], created=1713896579, model='https://api.clarifai.com/v2/users/anthropic/apps/completion/models/claude-2_1/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=12, completion_tokens=232, total_tokens=244))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from litellm import completion\n",
|
||||
"\n",
|
||||
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
|
||||
"response=completion(\n",
|
||||
" model=\"clarifai/anthropic.completion.claude-2_1\",\n",
|
||||
" messages=messages,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(f\"Claude-2.1 response : {response}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### OpenAI GPT-4 (Streaming)\n",
|
||||
"Though clarifai doesn't support streaming, still you can call stream and get the response in standard StreamResponse format of liteLLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"In the quiet corners of time's grand hall,\\nLies the tale of rise and fall.\\nFrom ancient ruins to modern sprawl,\\nHistory, the greatest story of them all.\\n\\nEmpires have risen, empires have decayed,\\nThrough the eons, memories have stayed.\\nIn the book of time, history is laid,\\nA tapestry of events, meticulously displayed.\\n\\nThe pyramids of Egypt, standing tall,\\nThe Roman Empire's mighty sprawl.\\nFrom Alexander's conquest, to the Berlin Wall,\\nHistory, a silent witness to it all.\\n\\nIn the shadow of the past we tread,\\nWhere once kings and prophets led.\\nTheir stories in our hearts are spread,\\nEchoes of their words, in our minds are read.\\n\\nBattles fought and victories won,\\nActs of courage under the sun.\\nTales of love, of deeds done,\\nIn history's grand book, they all run.\\n\\nHeroes born, legends made,\\nIn the annals of time, they'll never fade.\\nTheir triumphs and failures all displayed,\\nIn the eternal march of history's parade.\\n\\nThe ink of the past is forever dry,\\nBut its lessons, we cannot deny.\\nIn its stories, truths lie,\\nIn its wisdom, we rely.\\n\\nHistory, a mirror to our past,\\nA guide for the future vast.\\nThrough its lens, we're ever cast,\\nIn the drama of life, forever vast.\", role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n",
|
||||
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from litellm import completion\n",
|
||||
"\n",
|
||||
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
|
||||
"response = completion(\n",
|
||||
" model=\"clarifai/openai.chat-completion.GPT-4\",\n",
|
||||
" messages=messages,\n",
|
||||
" stream=True,\n",
|
||||
" api_key = \"c75cc032415e45368be331fdd2c06db0\")\n",
|
||||
"\n",
|
||||
"for chunk in response:\n",
|
||||
" print(chunk)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Caching - In-Memory, Redis, s3, Redis Semantic Cache
|
||||
# Caching - In-Memory, Redis, s3, Redis Semantic Cache, Disk
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)
|
||||
|
||||
|
@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht
|
|||
|
||||
:::
|
||||
|
||||
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache
|
||||
## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache
|
||||
|
||||
|
||||
<Tabs>
|
||||
|
@ -159,7 +159,7 @@ litellm.cache = Cache()
|
|||
# Make completion calls
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}]
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
caching=True
|
||||
)
|
||||
response2 = completion(
|
||||
|
@ -174,6 +174,43 @@ response2 = completion(
|
|||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="disk" label="disk cache">
|
||||
|
||||
### Quick Start
|
||||
|
||||
Install diskcache:
|
||||
|
||||
```shell
|
||||
pip install diskcache
|
||||
```
|
||||
|
||||
Then you can use the disk cache as follows.
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm import completion
|
||||
from litellm.caching import Cache
|
||||
litellm.cache = Cache(type="disk")
|
||||
|
||||
# Make completion calls
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
caching=True
|
||||
)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
caching=True
|
||||
)
|
||||
|
||||
# response1 == response2, response 1 is cached
|
||||
|
||||
```
|
||||
|
||||
If you run the code two times, response1 will use the cache from the first run that was stored in a cache file.
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
@ -191,13 +228,13 @@ Advanced Params
|
|||
|
||||
```python
|
||||
litellm.enable_cache(
|
||||
type: Optional[Literal["local", "redis"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding"],
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
@ -215,13 +252,13 @@ Update the Cache params
|
|||
|
||||
```python
|
||||
litellm.update_cache(
|
||||
type: Optional[Literal["local", "redis"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding"],
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
@ -276,22 +313,29 @@ cache.get_cache = get_cache
|
|||
```python
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[Literal["local", "redis", "s3"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
|
||||
supported_call_types: Optional[
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding"], # A list of litellm call types to cache for. Defaults to caching for all litellm call types.
|
||||
List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]]
|
||||
] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"],
|
||||
ttl: Optional[float] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
|
||||
# redis cache params
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
||||
namespace: Optional[str] = None,
|
||||
default_in_redis_ttl: Optional[float] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
redis_semantic_cache_use_async=False,
|
||||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
redis_flush_size=None,
|
||||
|
||||
# s3 Bucket, boto3 configuration
|
||||
s3_bucket_name: Optional[str] = None,
|
||||
s3_region_name: Optional[str] = None,
|
||||
s3_api_version: Optional[str] = None,
|
||||
s3_path: Optional[str] = None, # if you wish to save to a spefic path
|
||||
s3_path: Optional[str] = None, # if you wish to save to a specific path
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify: Optional[Union[bool, str]] = None,
|
||||
s3_endpoint_url: Optional[str] = None,
|
||||
|
@ -299,7 +343,11 @@ def __init__(
|
|||
s3_aws_secret_access_key: Optional[str] = None,
|
||||
s3_aws_session_token: Optional[str] = None,
|
||||
s3_config: Optional[Any] = None,
|
||||
**kwargs,
|
||||
|
||||
# disk cache params
|
||||
disk_cache_dir=None,
|
||||
|
||||
**kwargs
|
||||
):
|
||||
```
|
||||
|
|
@ -40,7 +40,7 @@ cache = Cache()
|
|||
|
||||
cache.add_cache(cache_key="test-key", result="1234")
|
||||
|
||||
cache.get_cache(cache_key="test-key)
|
||||
cache.get_cache(cache_key="test-key")
|
||||
```
|
||||
|
||||
## Caching with Streaming
|
||||
|
|
|
@ -4,6 +4,12 @@ LiteLLM allows you to:
|
|||
* Send 1 completion call to many models: Return Fastest Response
|
||||
* Send 1 completion call to many models: Return All Responses
|
||||
|
||||
:::info
|
||||
|
||||
Trying to do batch completion on LiteLLM Proxy ? Go here: https://docs.litellm.ai/docs/proxy/user_keys#beta-batch-completions---pass-model-as-list
|
||||
|
||||
:::
|
||||
|
||||
## Send multiple completion calls to 1 model
|
||||
|
||||
In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.
|
||||
|
|
|
@ -37,11 +37,12 @@ print(response) # ["max_tokens", "tools", "tool_choice", "stream"]
|
|||
|
||||
This is a list of openai params we translate across providers.
|
||||
|
||||
This list is constantly being updated.
|
||||
Use `litellm.get_supported_openai_params()` for an updated list of params for each model + provider
|
||||
|
||||
| Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|
||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--|
|
||||
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ | ✅ |
|
||||
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | ✅ | ✅ | ✅ | ✅ |
|
||||
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|
||||
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
|
|
|
@ -106,11 +106,12 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
|
|||
|
||||
## Custom mapping list
|
||||
|
||||
Base case - we return the original exception.
|
||||
Base case - we return `litellm.APIConnectionError` exception (inherits from openai's APIConnectionError exception).
|
||||
|
||||
| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|
||||
|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
|
||||
| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| watsonx | | | | | | | |✓| | | |
|
||||
| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
|
|
173
docs/my-website/docs/observability/lago.md
Normal file
173
docs/my-website/docs/observability/lago.md
Normal file
|
@ -0,0 +1,173 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Lago - Usage Based Billing
|
||||
|
||||
[Lago](https://www.getlago.com/) offers a self-hosted and cloud, metering and usage-based billing solution.
|
||||
|
||||
<Image img={require('../../img/lago.jpeg')} />
|
||||
|
||||
## Quick Start
|
||||
Use just 1 lines of code, to instantly log your responses **across all providers** with Lago
|
||||
|
||||
Get your Lago [API Key](https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key)
|
||||
|
||||
```python
|
||||
litellm.callbacks = ["lago"] # logs cost + usage of successful calls to lago
|
||||
```
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
# pip install lago
|
||||
import litellm
|
||||
import os
|
||||
|
||||
os.environ["LAGO_API_BASE"] = "" # http://0.0.0.0:3000
|
||||
os.environ["LAGO_API_KEY"] = ""
|
||||
os.environ["LAGO_API_EVENT_CODE"] = "" # The billable metric's code - https://docs.getlago.com/guide/events/ingesting-usage#define-a-billable-metric
|
||||
|
||||
# LLM API Keys
|
||||
os.environ['OPENAI_API_KEY']=""
|
||||
|
||||
# set lago as a callback, litellm will send the data to lago
|
||||
litellm.success_callback = ["lago"]
|
||||
|
||||
# openai call
|
||||
response = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hi 👋 - i'm openai"}
|
||||
],
|
||||
user="your_customer_id" # 👈 SET YOUR CUSTOMER ID HERE
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add to Config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- litellm_params:
|
||||
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||
api_key: my-fake-key
|
||||
model: openai/my-fake-model
|
||||
model_name: fake-openai-endpoint
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["lago"] # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
||||
```
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "fake-openai-endpoint",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
"user": "your-customer-id" # 👈 SET YOUR CUSTOMER ID
|
||||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai_python" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
], user="my_customer_id") # 👈 whatever your customer id is
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "anything"
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000",
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1,
|
||||
extra_body={
|
||||
"user": "my_customer_id" # 👈 whatever your customer id is
|
||||
}
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
<Image img={require('../../img/lago_2.png')} />
|
||||
|
||||
## Advanced - Lagos Logging object
|
||||
|
||||
This is what LiteLLM will log to Lagos
|
||||
|
||||
```
|
||||
{
|
||||
"event": {
|
||||
"transaction_id": "<generated_unique_id>",
|
||||
"external_customer_id": <litellm_end_user_id>, # passed via `user` param in /chat/completion call - https://platform.openai.com/docs/api-reference/chat/create
|
||||
"code": os.getenv("LAGO_API_EVENT_CODE"),
|
||||
"properties": {
|
||||
"input_tokens": <number>,
|
||||
"output_tokens": <number>,
|
||||
"model": <string>,
|
||||
"response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
|
@ -136,6 +136,7 @@ response = completion(
|
|||
"existing_trace_id": "trace-id22",
|
||||
"trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
|
||||
"update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
|
||||
"debug_langfuse": True, # Will log the exact metadata sent to litellm for the trace/generation as `metadata_passed_to_litellm`
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -147,7 +148,7 @@ print(response)
|
|||
|
||||
#### Trace Specific Parameters
|
||||
|
||||
* `trace_id` - Identifier for the trace, must use `existing_trace_id` instead or in conjunction with `trace_id` if this is an existing trace, auto-generated by default
|
||||
* `trace_id` - Identifier for the trace, must use `existing_trace_id` instead of `trace_id` if this is an existing trace, auto-generated by default
|
||||
* `trace_name` - Name of the trace, auto-generated by default
|
||||
* `session_id` - Session identifier for the trace, defaults to `None`
|
||||
* `trace_version` - Version for the trace, defaults to value for `version`
|
||||
|
@ -212,8 +213,20 @@ chat(messages)
|
|||
|
||||
## Redacting Messages, Response Content from Langfuse Logging
|
||||
|
||||
### Redact Messages and Responses from all Langfuse Logging
|
||||
|
||||
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged.
|
||||
|
||||
### Redact Messages and Responses from specific Langfuse Logging
|
||||
|
||||
In the metadata typically passed for text completion or embedding calls you can set specific keys to mask the messages and responses for this call.
|
||||
|
||||
Setting `mask_input` to `True` will mask the input from being logged for this call
|
||||
|
||||
Setting `mask_output` to `True` will make the output from being logged for this call.
|
||||
|
||||
Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message.
|
||||
|
||||
## Troubleshooting & Errors
|
||||
### Data not getting logged to Langfuse ?
|
||||
- Ensure you're on the latest version of langfuse `pip install langfuse -U`. The latest version allows litellm to log JSON input/outputs to langfuse
|
||||
|
|
|
@ -20,7 +20,7 @@ Use just 2 lines of code, to instantly log your responses **across all providers
|
|||
Get your OpenMeter API Key from https://openmeter.cloud/meters
|
||||
|
||||
```python
|
||||
litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls to openmeter
|
||||
litellm.callbacks = ["openmeter"] # logs cost + usage of successful calls to openmeter
|
||||
```
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls
|
|||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
# pip install langfuse
|
||||
# pip install openmeter
|
||||
import litellm
|
||||
import os
|
||||
|
||||
|
@ -39,8 +39,8 @@ os.environ["OPENMETER_API_KEY"] = ""
|
|||
# LLM API Keys
|
||||
os.environ['OPENAI_API_KEY']=""
|
||||
|
||||
# set langfuse as a callback, litellm will send the data to langfuse
|
||||
litellm.success_callback = ["openmeter"]
|
||||
# set openmeter as a callback, litellm will send the data to openmeter
|
||||
litellm.callbacks = ["openmeter"]
|
||||
|
||||
# openai call
|
||||
response = litellm.completion(
|
||||
|
@ -64,7 +64,7 @@ model_list:
|
|||
model_name: fake-openai-endpoint
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["openmeter"] # 👈 KEY CHANGE
|
||||
callbacks: ["openmeter"] # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
|
177
docs/my-website/docs/providers/clarifai.md
Normal file
177
docs/my-website/docs/providers/clarifai.md
Normal file
|
@ -0,0 +1,177 @@
|
|||
|
||||
# Clarifai
|
||||
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.
|
||||
|
||||
## Pre-Requisites
|
||||
|
||||
`pip install clarifai`
|
||||
|
||||
`pip install litellm`
|
||||
|
||||
## Required Environment Variables
|
||||
To obtain your Clarifai Personal access token follow this [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/). Optionally the PAT can also be passed in `completion` function.
|
||||
|
||||
```python
|
||||
os.environ["CALRIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
||||
os.environ["CLARIFAI_API_KEY"] = ""
|
||||
|
||||
response = completion(
|
||||
model="clarifai/mistralai.completion.mistral-large",
|
||||
messages=[{ "content": "Tell me a joke about physics?","role": "user"}]
|
||||
)
|
||||
```
|
||||
|
||||
**Output**
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-572701ee-9ab2-411c-ac75-46c1ba18e781",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 1,
|
||||
"message": {
|
||||
"content": "Sure, here's a physics joke for you:\n\nWhy can't you trust an atom?\n\nBecause they make up everything!",
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1714410197,
|
||||
"model": "https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": null,
|
||||
"usage": {
|
||||
"prompt_tokens": 14,
|
||||
"completion_tokens": 24,
|
||||
"total_tokens": 38
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Clarifai models
|
||||
liteLLM supports non-streaming requests to all models on [Clarifai community](https://clarifai.com/explore/models?filterData=%5B%7B%22field%22%3A%22use_cases%22%2C%22value%22%3A%5B%22llm%22%5D%7D%5D&page=1&perPage=24)
|
||||
|
||||
Example Usage - Note: liteLLM supports all models deployed on Clarifai
|
||||
|
||||
## Llama LLMs
|
||||
| Model Name | Function Call |
|
||||
---------------------------|---------------------------------|
|
||||
| clarifai/meta.Llama-2.llama2-7b-chat | `completion('clarifai/meta.Llama-2.llama2-7b-chat', messages)`
|
||||
| clarifai/meta.Llama-2.llama2-13b-chat | `completion('clarifai/meta.Llama-2.llama2-13b-chat', messages)`
|
||||
| clarifai/meta.Llama-2.llama2-70b-chat | `completion('clarifai/meta.Llama-2.llama2-70b-chat', messages)` |
|
||||
| clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`|
|
||||
| clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` |
|
||||
|
||||
## Mistal LLMs
|
||||
| Model Name | Function Call |
|
||||
|---------------------------------------------|------------------------------------------------------------------------|
|
||||
| clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` |
|
||||
| clarifai/mistralai.completion.mistral-large | `completion('clarifai/mistralai.completion.mistral-large', messages)` |
|
||||
| clarifai/mistralai.completion.mistral-medium | `completion('clarifai/mistralai.completion.mistral-medium', messages)` |
|
||||
| clarifai/mistralai.completion.mistral-small | `completion('clarifai/mistralai.completion.mistral-small', messages)` |
|
||||
| clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1 | `completion('clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1', messages)`
|
||||
| clarifai/mistralai.completion.mistral-7B-OpenOrca | `completion('clarifai/mistralai.completion.mistral-7B-OpenOrca', messages)` |
|
||||
| clarifai/mistralai.completion.openHermes-2-mistral-7B | `completion('clarifai/mistralai.completion.openHermes-2-mistral-7B', messages)` |
|
||||
|
||||
|
||||
## Jurassic LLMs
|
||||
| Model Name | Function Call |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/ai21.complete.Jurassic2-Grande | `completion('clarifai/ai21.complete.Jurassic2-Grande', messages)` |
|
||||
| clarifai/ai21.complete.Jurassic2-Grande-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Grande-Instruct', messages)` |
|
||||
| clarifai/ai21.complete.Jurassic2-Jumbo-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Jumbo-Instruct', messages)` |
|
||||
| clarifai/ai21.complete.Jurassic2-Jumbo | `completion('clarifai/ai21.complete.Jurassic2-Jumbo', messages)` |
|
||||
| clarifai/ai21.complete.Jurassic2-Large | `completion('clarifai/ai21.complete.Jurassic2-Large', messages)` |
|
||||
|
||||
## Wizard LLMs
|
||||
|
||||
| Model Name | Function Call |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/wizardlm.generate.wizardCoder-Python-34B | `completion('clarifai/wizardlm.generate.wizardCoder-Python-34B', messages)` |
|
||||
| clarifai/wizardlm.generate.wizardLM-70B | `completion('clarifai/wizardlm.generate.wizardLM-70B', messages)` |
|
||||
| clarifai/wizardlm.generate.wizardLM-13B | `completion('clarifai/wizardlm.generate.wizardLM-13B', messages)` |
|
||||
| clarifai/wizardlm.generate.wizardCoder-15B | `completion('clarifai/wizardlm.generate.wizardCoder-15B', messages)` |
|
||||
|
||||
## Anthropic models
|
||||
|
||||
| Model Name | Function Call |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/anthropic.completion.claude-v1 | `completion('clarifai/anthropic.completion.claude-v1', messages)` |
|
||||
| clarifai/anthropic.completion.claude-instant-1_2 | `completion('clarifai/anthropic.completion.claude-instant-1_2', messages)` |
|
||||
| clarifai/anthropic.completion.claude-instant | `completion('clarifai/anthropic.completion.claude-instant', messages)` |
|
||||
| clarifai/anthropic.completion.claude-v2 | `completion('clarifai/anthropic.completion.claude-v2', messages)` |
|
||||
| clarifai/anthropic.completion.claude-2_1 | `completion('clarifai/anthropic.completion.claude-2_1', messages)` |
|
||||
| clarifai/anthropic.completion.claude-3-opus | `completion('clarifai/anthropic.completion.claude-3-opus', messages)` |
|
||||
| clarifai/anthropic.completion.claude-3-sonnet | `completion('clarifai/anthropic.completion.claude-3-sonnet', messages)` |
|
||||
|
||||
## OpenAI GPT LLMs
|
||||
|
||||
| Model Name | Function Call |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/openai.chat-completion.GPT-4 | `completion('clarifai/openai.chat-completion.GPT-4', messages)` |
|
||||
| clarifai/openai.chat-completion.GPT-3_5-turbo | `completion('clarifai/openai.chat-completion.GPT-3_5-turbo', messages)` |
|
||||
| clarifai/openai.chat-completion.gpt-4-turbo | `completion('clarifai/openai.chat-completion.gpt-4-turbo', messages)` |
|
||||
| clarifai/openai.completion.gpt-3_5-turbo-instruct | `completion('clarifai/openai.completion.gpt-3_5-turbo-instruct', messages)` |
|
||||
|
||||
## GCP LLMs
|
||||
|
||||
| Model Name | Function Call |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/gcp.generate.gemini-1_5-pro | `completion('clarifai/gcp.generate.gemini-1_5-pro', messages)` |
|
||||
| clarifai/gcp.generate.imagen-2 | `completion('clarifai/gcp.generate.imagen-2', messages)` |
|
||||
| clarifai/gcp.generate.code-gecko | `completion('clarifai/gcp.generate.code-gecko', messages)` |
|
||||
| clarifai/gcp.generate.code-bison | `completion('clarifai/gcp.generate.code-bison', messages)` |
|
||||
| clarifai/gcp.generate.text-bison | `completion('clarifai/gcp.generate.text-bison', messages)` |
|
||||
| clarifai/gcp.generate.gemma-2b-it | `completion('clarifai/gcp.generate.gemma-2b-it', messages)` |
|
||||
| clarifai/gcp.generate.gemma-7b-it | `completion('clarifai/gcp.generate.gemma-7b-it', messages)` |
|
||||
| clarifai/gcp.generate.gemini-pro | `completion('clarifai/gcp.generate.gemini-pro', messages)` |
|
||||
| clarifai/gcp.generate.gemma-1_1-7b-it | `completion('clarifai/gcp.generate.gemma-1_1-7b-it', messages)` |
|
||||
|
||||
## Cohere LLMs
|
||||
| Model Name | Function Call |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/cohere.generate.cohere-generate-command | `completion('clarifai/cohere.generate.cohere-generate-command', messages)` |
|
||||
clarifai/cohere.generate.command-r-plus' | `completion('clarifai/clarifai/cohere.generate.command-r-plus', messages)`|
|
||||
|
||||
## Databricks LLMs
|
||||
|
||||
| Model Name | Function Call |
|
||||
|---------------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/databricks.drbx.dbrx-instruct | `completion('clarifai/databricks.drbx.dbrx-instruct', messages)` |
|
||||
| clarifai/databricks.Dolly-v2.dolly-v2-12b | `completion('clarifai/databricks.Dolly-v2.dolly-v2-12b', messages)`|
|
||||
|
||||
## Microsoft LLMs
|
||||
|
||||
| Model Name | Function Call |
|
||||
|---------------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/microsoft.text-generation.phi-2 | `completion('clarifai/microsoft.text-generation.phi-2', messages)` |
|
||||
| clarifai/microsoft.text-generation.phi-1_5 | `completion('clarifai/microsoft.text-generation.phi-1_5', messages)`|
|
||||
|
||||
## Salesforce models
|
||||
|
||||
| Model Name | Function Call |
|
||||
|-----------------------------------------------------------|-------------------------------------------------------------------------------|
|
||||
| clarifai/salesforce.blip.general-english-image-caption-blip-2 | `completion('clarifai/salesforce.blip.general-english-image-caption-blip-2', messages)` |
|
||||
| clarifai/salesforce.xgen.xgen-7b-8k-instruct | `completion('clarifai/salesforce.xgen.xgen-7b-8k-instruct', messages)` |
|
||||
|
||||
|
||||
## Other Top performing LLMs
|
||||
|
||||
| Model Name | Function Call |
|
||||
|---------------------------------------------------|---------------------------------------------------------------------|
|
||||
| clarifai/deci.decilm.deciLM-7B-instruct | `completion('clarifai/deci.decilm.deciLM-7B-instruct', messages)` |
|
||||
| clarifai/upstage.solar.solar-10_7b-instruct | `completion('clarifai/upstage.solar.solar-10_7b-instruct', messages)` |
|
||||
| clarifai/openchat.openchat.openchat-3_5-1210 | `completion('clarifai/openchat.openchat.openchat-3_5-1210', messages)` |
|
||||
| clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B | `completion('clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B', messages)` |
|
||||
| clarifai/fblgit.una-cybertron.una-cybertron-7b-v2 | `completion('clarifai/fblgit.una-cybertron.una-cybertron-7b-v2', messages)` |
|
||||
| clarifai/tiiuae.falcon.falcon-40b-instruct | `completion('clarifai/tiiuae.falcon.falcon-40b-instruct', messages)` |
|
||||
| clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat | `completion('clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat', messages)` |
|
||||
| clarifai/bigcode.code.StarCoder | `completion('clarifai/bigcode.code.StarCoder', messages)` |
|
||||
| clarifai/mosaicml.mpt.mpt-7b-instruct | `completion('clarifai/mosaicml.mpt.mpt-7b-instruct', messages)` |
|
|
@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
|
|||
<Tabs>
|
||||
<TabItem value="tgi" label="Text-generation-interface (TGI)">
|
||||
|
||||
By default, LiteLLM will assume a huggingface call follows the TGI format.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
@ -40,9 +45,58 @@ response = completion(
|
|||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: wizard-coder
|
||||
litellm_params:
|
||||
model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
|
||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "wizard-coder",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I like you!"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
|
||||
|
||||
Append `conversational` to the model name
|
||||
|
||||
e.g. `huggingface/conversational/<model-name>`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
|
|||
|
||||
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
|
||||
response = completion(
|
||||
model="huggingface/facebook/blenderbot-400M-distill",
|
||||
model="huggingface/conversational/facebook/blenderbot-400M-distill",
|
||||
messages=messages,
|
||||
api_base="https://my-endpoint.huggingface.cloud"
|
||||
)
|
||||
|
@ -62,7 +116,123 @@ response = completion(
|
|||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="none" label="Non TGI/Conversational-task LLMs">
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: blenderbot
|
||||
litellm_params:
|
||||
model: huggingface/conversational/facebook/blenderbot-400M-distill
|
||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "blenderbot",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I like you!"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="classification" label="Text Classification">
|
||||
|
||||
Append `text-classification` to the model name
|
||||
|
||||
e.g. `huggingface/text-classification/<model-name>`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
||||
# [OPTIONAL] set env var
|
||||
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
|
||||
|
||||
messages = [{ "content": "I like you, I love you!","role": "user"}]
|
||||
|
||||
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
|
||||
response = completion(
|
||||
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
|
||||
messages=messages,
|
||||
api_base="https://my-endpoint.endpoints.huggingface.cloud",
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: bert-classifier
|
||||
litellm_params:
|
||||
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
|
||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "bert-classifier",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I like you!"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="none" label="Text Generation (NOT TGI)">
|
||||
|
||||
Append `text-generation` to the model name
|
||||
|
||||
e.g. `huggingface/text-generation/<model-name>`
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
|
|||
|
||||
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
|
||||
response = completion(
|
||||
model="huggingface/roneneldan/TinyStories-3M",
|
||||
model="huggingface/text-generation/roneneldan/TinyStories-3M",
|
||||
messages=messages,
|
||||
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
|
||||
)
|
||||
|
|
|
@ -102,12 +102,18 @@ Ollama supported models: https://github.com/ollama/ollama
|
|||
| Model Name | Function Call |
|
||||
|----------------------|-----------------------------------------------------------------------------------
|
||||
| Mistral | `completion(model='ollama/mistral', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Mistral-7B-Instruct-v0.1 | `completion(model='ollama/mistral-7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Mistral-7B-Instruct-v0.2 | `completion(model='ollama/mistral-7B-Instruct-v0.2', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Mixtral-8x7B-Instruct-v0.1 | `completion(model='ollama/mistral-8x7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Mixtral-8x22B-Instruct-v0.1 | `completion(model='ollama/mixtral-8x22B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Llama2 7B | `completion(model='ollama/llama2', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 13B | `completion(model='ollama/llama2:13b', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 70B | `completion(model='ollama/llama2:70b', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Code Llama | `completion(model='ollama/codellama', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
|Meta LLaMa3 8B | `completion(model='ollama/llama3', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Meta LLaMa3 70B | `completion(model='ollama/llama3:70b', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Orca Mini | `completion(model='ollama/orca-mini', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Vicuna | `completion(model='ollama/vicuna', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Nous-Hermes | `completion(model='ollama/nous-hermes', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
|
|
|
@ -20,7 +20,7 @@ os.environ["OPENAI_API_KEY"] = "your-api-key"
|
|||
|
||||
# openai call
|
||||
response = completion(
|
||||
model = "gpt-3.5-turbo",
|
||||
model = "gpt-4o",
|
||||
messages=[{ "content": "Hello, how are you?","role": "user"}]
|
||||
)
|
||||
```
|
||||
|
@ -163,6 +163,8 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
|
|||
|
||||
| Model Name | Function Call |
|
||||
|-----------------------|-----------------------------------------------------------------|
|
||||
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
|
||||
| gpt-4o-2024-05-13 | `response = completion(model="gpt-4o-2024-05-13", messages=messages)` |
|
||||
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
|
||||
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
||||
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
||||
|
@ -186,6 +188,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
|
|||
## OpenAI Vision Models
|
||||
| Model Name | Function Call |
|
||||
|-----------------------|-----------------------------------------------------------------|
|
||||
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
|
||||
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
|
||||
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |
|
||||
|
||||
|
|
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
|
@ -0,0 +1,95 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Triton Inference Server
|
||||
|
||||
LiteLLM supports Embedding Models on Triton Inference Servers
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
|
||||
### Example Call
|
||||
|
||||
Use the `triton/` prefix to route to triton server
|
||||
```python
|
||||
from litellm import embedding
|
||||
import os
|
||||
|
||||
response = await litellm.aembedding(
|
||||
model="triton/<your-triton-model>",
|
||||
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
|
||||
input=["good morning from litellm"],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: my-triton-model
|
||||
litellm_params:
|
||||
model: triton/<your-triton-model>"
|
||||
api_base: https://your-triton-api-base/triton/embeddings
|
||||
```
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
3. Send Request to LiteLLM Proxy Server
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
# set base_url to your proxy server
|
||||
# set api_key to send to proxy server
|
||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.embeddings.create(
|
||||
input=["hello from litellm"],
|
||||
model="my-triton-model"
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="curl">
|
||||
|
||||
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data ' {
|
||||
"model": "my-triton-model",
|
||||
"input": ["write a litellm poem"]
|
||||
}'
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
|
@ -364,6 +364,8 @@ response = completion(
|
|||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
|
||||
| gemini-1.5-flash-preview-0514 | `completion('gemini-1.5-flash-preview-0514', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
|
||||
| gemini-1.5-pro-preview-0514 | `completion('gemini-1.5-pro-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-pro-preview-0514', messages)` |
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
# 🚨 Alerting
|
||||
|
||||
Get alerts for:
|
||||
|
||||
- Hanging LLM api calls
|
||||
- Failed LLM api calls
|
||||
- Slow LLM api calls
|
||||
- Budget Tracking per key/user:
|
||||
- When a User/Key crosses their Budget
|
||||
- When a User/Key is 15% away from crossing their Budget
|
||||
- Failed LLM api calls
|
||||
- Budget Tracking per key/user
|
||||
- Spend Reports - Weekly & Monthly spend per Team, Tag
|
||||
- Failed db read/writes
|
||||
- Daily Reports:
|
||||
- **LLM** Top 5 slowest deployments
|
||||
- **LLM** Top 5 deployments with most failed requests
|
||||
- **Spend** Weekly & Monthly spend per Team, Tag
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
@ -17,10 +22,12 @@ Set up a slack alert channel to receive alerts from proxy.
|
|||
|
||||
Get a slack webhook url from https://api.slack.com/messaging/webhooks
|
||||
|
||||
You can also use Discord Webhooks, see [here](#using-discord-webhooks)
|
||||
|
||||
### Step 2: Update config.yaml
|
||||
|
||||
Let's save a bad key to our proxy
|
||||
- Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts.
|
||||
- Just for testing purposes, let's save a bad key to our proxy.
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
|
@ -33,16 +40,59 @@ general_settings:
|
|||
alerting: ["slack"]
|
||||
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
|
||||
|
||||
environment_variables:
|
||||
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
|
||||
SLACK_DAILY_REPORT_FREQUENCY: "86400" # 24 hours; Optional: defaults to 12 hours
|
||||
```
|
||||
|
||||
Set `SLACK_WEBHOOK_URL` in your proxy env
|
||||
|
||||
```shell
|
||||
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
|
||||
```
|
||||
|
||||
### Step 3: Start proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
## Testing Alerting is Setup Correctly
|
||||
|
||||
Make a GET request to `/health/services`, expect to see a test slack alert in your provided webhook slack channel
|
||||
|
||||
```shell
|
||||
curl -X GET 'http://localhost:4000/health/services?service=slack' \
|
||||
-H 'Authorization: Bearer sk-1234'
|
||||
```
|
||||
|
||||
|
||||
## Extras
|
||||
|
||||
### Using Discord Webhooks
|
||||
|
||||
Discord provides a slack compatible webhook url that you can use for alerting
|
||||
|
||||
##### Quick Start
|
||||
|
||||
1. Get a webhook url for your discord channel
|
||||
|
||||
2. Append `/slack` to your discord webhook - it should look like
|
||||
|
||||
```
|
||||
"https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack"
|
||||
```
|
||||
|
||||
3. Add it to your litellm config
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
model_name: "azure-model"
|
||||
litellm_params:
|
||||
model: "azure/gpt-35-turbo"
|
||||
api_key: "my-bad-key" # 👈 bad key
|
||||
|
||||
general_settings:
|
||||
alerting: ["slack"]
|
||||
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
|
||||
|
||||
environment_variables:
|
||||
SLACK_WEBHOOK_URL: "https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack"
|
||||
```
|
||||
|
||||
That's it ! You're ready to go !
|
||||
|
|
229
docs/my-website/docs/proxy/billing.md
Normal file
229
docs/my-website/docs/proxy/billing.md
Normal file
|
@ -0,0 +1,229 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 💰 Billing
|
||||
|
||||
Bill users for their usage.
|
||||
|
||||
Requirements:
|
||||
- Setup a billing plan on Lago, for usage-based billing. We recommend following their Stripe tutorial - https://docs.getlago.com/templates/per-transaction/stripe#step-1-create-billable-metrics-for-transaction
|
||||
|
||||
Steps:
|
||||
- Connect the proxy to Lago
|
||||
- Set the id you want to bill for (customers, internal users, teams)
|
||||
- Start!
|
||||
|
||||
## 1. Connect proxy to Lago
|
||||
|
||||
Add your Lago keys to the environment
|
||||
|
||||
```bash
|
||||
export LAGO_API_BASE="http://localhost:3000" # self-host - https://docs.getlago.com/guide/self-hosted/docker#run-the-app
|
||||
export LAGO_API_KEY="3e29d607-de54-49aa-a019-ecf585729070" # Get key - https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key
|
||||
export LAGO_API_EVENT_CODE="openai_tokens" # name of lago billing code
|
||||
```
|
||||
|
||||
Set 'lago' as a callback on your proxy config.yaml
|
||||
|
||||
```yaml
|
||||
...
|
||||
litellm_settings:
|
||||
callbacks: ["lago"]
|
||||
```
|
||||
|
||||
## 2. Set the id you want to bill for
|
||||
|
||||
For:
|
||||
- Customers (id passed via 'user' param in /chat/completion call) = 'end_user_id'
|
||||
- Internal Users (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'user_id'
|
||||
- Teams (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'team_id'
|
||||
|
||||
```yaml
|
||||
export LAGO_API_CHARGE_BY="end_user_id" # 👈 Charge 'Customers'. Default is 'end_user_id'.
|
||||
```
|
||||
|
||||
## 3. Start billing!
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="customers" label="Customer Billing">
|
||||
|
||||
### **Curl**
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
"user": "my_customer_id" # 👈 whatever your customer id is
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
### **OpenAI Python SDK**
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
], user="my_customer_id") # 👈 whatever your customer id is
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
### **Langchain**
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "anything"
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000",
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1,
|
||||
extra_body={
|
||||
"user": "my_customer_id" # 👈 whatever your customer id is
|
||||
}
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="internal_user" label="Internal User (Key Owner) Billing">
|
||||
|
||||
1. Create a key for that user
|
||||
|
||||
```bash
|
||||
curl 'http://0.0.0.0:4000/key/generate' \
|
||||
--header 'Authorization: Bearer <your-master-key>' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{"user_id": "my-unique-id"}'
|
||||
```
|
||||
|
||||
Response Object:
|
||||
|
||||
```bash
|
||||
{
|
||||
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
|
||||
}
|
||||
```
|
||||
|
||||
2. Make API Calls with that Key
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Generated key
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="teams" label="Team Billing">
|
||||
|
||||
1. Create a key for that team
|
||||
|
||||
```bash
|
||||
curl 'http://0.0.0.0:4000/key/generate' \
|
||||
--header 'Authorization: Bearer <your-master-key>' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{"team_id": "my-unique-id"}'
|
||||
```
|
||||
|
||||
Response Object:
|
||||
|
||||
```bash
|
||||
{
|
||||
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
|
||||
}
|
||||
```
|
||||
|
||||
2. Make API Calls with that Key
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Generated key
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**See Results on Lago**
|
||||
|
||||
|
||||
<Image img={require('../../img/lago_2.png')} style={{ width: '500px', height: 'auto' }} />
|
||||
|
||||
## Advanced - Lago Logging object
|
||||
|
||||
This is what LiteLLM will log to Lagos
|
||||
|
||||
```
|
||||
{
|
||||
"event": {
|
||||
"transaction_id": "<generated_unique_id>",
|
||||
"external_customer_id": <selected_id>, # either 'end_user_id', 'user_id', or 'team_id'. Default 'end_user_id'.
|
||||
"code": os.getenv("LAGO_API_EVENT_CODE"),
|
||||
"properties": {
|
||||
"input_tokens": <number>,
|
||||
"output_tokens": <number>,
|
||||
"model": <string>,
|
||||
"response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
|
@ -1,8 +1,161 @@
|
|||
# Cost Tracking - Azure
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 💸 Spend Tracking
|
||||
|
||||
Track spend for keys, users, and teams across 100+ LLMs.
|
||||
|
||||
## Getting Spend Reports - To Charge Other Teams, API Keys
|
||||
|
||||
Use the `/global/spend/report` endpoint to get daily spend per team, with a breakdown of spend per API Key, Model
|
||||
|
||||
### Example Request
|
||||
|
||||
```shell
|
||||
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30' \
|
||||
-H 'Authorization: Bearer sk-1234'
|
||||
```
|
||||
|
||||
### Example Response
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="response" label="Expected Response">
|
||||
|
||||
```shell
|
||||
[
|
||||
{
|
||||
"group_by_day": "2024-04-30T00:00:00+00:00",
|
||||
"teams": [
|
||||
{
|
||||
"team_name": "Prod Team",
|
||||
"total_spend": 0.0015265,
|
||||
"metadata": [ # see the spend by unique(key + model)
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"spend": 0.00123,
|
||||
"total_tokens": 28,
|
||||
"api_key": "88dc28.." # the hashed api key
|
||||
},
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"spend": 0.00123,
|
||||
"total_tokens": 28,
|
||||
"api_key": "a73dc2.." # the hashed api key
|
||||
},
|
||||
{
|
||||
"model": "chatgpt-v-2",
|
||||
"spend": 0.000214,
|
||||
"total_tokens": 122,
|
||||
"api_key": "898c28.." # the hashed api key
|
||||
},
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"spend": 0.0000825,
|
||||
"total_tokens": 85,
|
||||
"api_key": "84dc28.." # the hashed api key
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="py-script" label="Script to Parse Response (Python)">
|
||||
|
||||
```python
|
||||
import requests
|
||||
url = 'http://localhost:4000/global/spend/report'
|
||||
params = {
|
||||
'start_date': '2023-04-01',
|
||||
'end_date': '2024-06-30'
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Authorization': 'Bearer sk-1234'
|
||||
}
|
||||
|
||||
# Make the GET request
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
spend_report = response.json()
|
||||
|
||||
for row in spend_report:
|
||||
date = row["group_by_day"]
|
||||
teams = row["teams"]
|
||||
for team in teams:
|
||||
team_name = team["team_name"]
|
||||
total_spend = team["total_spend"]
|
||||
metadata = team["metadata"]
|
||||
|
||||
print(f"Date: {date}")
|
||||
print(f"Team: {team_name}")
|
||||
print(f"Total Spend: {total_spend}")
|
||||
print("Metadata: ", metadata)
|
||||
print()
|
||||
```
|
||||
|
||||
Output from script
|
||||
```shell
|
||||
# Date: 2024-05-11T00:00:00+00:00
|
||||
# Team: local_test_team
|
||||
# Total Spend: 0.003675099999999999
|
||||
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.003675099999999999, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 3105}]
|
||||
|
||||
# Date: 2024-05-13T00:00:00+00:00
|
||||
# Team: Unassigned Team
|
||||
# Total Spend: 3.4e-05
|
||||
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 3.4e-05, 'api_key': '9569d13c9777dba68096dea49b0b03e0aaf4d2b65d4030eda9e8a2733c3cd6e0', 'total_tokens': 50}]
|
||||
|
||||
# Date: 2024-05-13T00:00:00+00:00
|
||||
# Team: central
|
||||
# Total Spend: 0.000684
|
||||
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.000684, 'api_key': '0323facdf3af551594017b9ef162434a9b9a8ca1bbd9ccbd9d6ce173b1015605', 'total_tokens': 498}]
|
||||
|
||||
# Date: 2024-05-13T00:00:00+00:00
|
||||
# Team: local_test_team
|
||||
# Total Spend: 0.0005715000000000001
|
||||
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.0005715000000000001, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 423}]
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
||||
## Reset Team, API Key Spend - MASTER KEY ONLY
|
||||
|
||||
Use `/global/spend/reset` if you want to:
|
||||
- Reset the Spend for all API Keys, Teams. The `spend` for ALL Teams and Keys in `LiteLLM_TeamTable` and `LiteLLM_VerificationToken` will be set to `spend=0`
|
||||
|
||||
- LiteLLM will maintain all the logs in `LiteLLMSpendLogs` for Auditing Purposes
|
||||
|
||||
### Request
|
||||
Only the `LITELLM_MASTER_KEY` you set can access this route
|
||||
```shell
|
||||
curl -X POST \
|
||||
'http://localhost:4000/global/spend/reset' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
### Expected Responses
|
||||
|
||||
```shell
|
||||
{"message":"Spend for all API Keys and Teams reset successfully","status":"success"}
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
## Spend Tracking for Azure
|
||||
|
||||
Set base model for cost tracking azure image-gen call
|
||||
|
||||
## Image Generation
|
||||
### Image Generation
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
|
@ -17,7 +170,7 @@ model_list:
|
|||
mode: image_generation
|
||||
```
|
||||
|
||||
## Chat Completions / Embeddings
|
||||
### Chat Completions / Embeddings
|
||||
|
||||
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import Tabs from '@theme/Tabs';
|
|||
import TabItem from '@theme/TabItem';
|
||||
|
||||
|
||||
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina
|
||||
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety
|
||||
|
||||
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
|
||||
|
||||
|
@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
|
|||
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
|
||||
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
|
||||
- [Logging to Athina](#logging-proxy-inputoutput-athina)
|
||||
- [(BETA) Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
|
||||
|
||||
## Custom Callback Class [Async]
|
||||
Use this when you want to run custom callbacks in `python`
|
||||
|
@ -1037,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## (BETA) Moderation with Azure Content Safety
|
||||
|
||||
[Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content moderation APIs to detect potential offensive, harmful, or risky content in text.
|
||||
|
||||
We will use the `--config` to set `litellm.success_callback = ["azure_content_safety"]` this will moderate all LLM calls using Azure Content Safety.
|
||||
|
||||
**Step 0** Deploy Azure Content Safety
|
||||
|
||||
Deploy an Azure Content-Safety instance from the Azure Portal and get the `endpoint` and `key`.
|
||||
|
||||
**Step 1** Set Athina API key
|
||||
|
||||
```shell
|
||||
AZURE_CONTENT_SAFETY_KEY = "<your-azure-content-safety-key>"
|
||||
```
|
||||
|
||||
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
callbacks: ["azure_content_safety"]
|
||||
azure_content_safety_params:
|
||||
endpoint: "<your-azure-content-safety-endpoint>"
|
||||
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
|
||||
```
|
||||
|
||||
**Step 3**: Start the proxy, make a test request
|
||||
|
||||
Start proxy
|
||||
```shell
|
||||
litellm --config config.yaml --debug
|
||||
```
|
||||
|
||||
Test Request
|
||||
```
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi, how are you?"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
An HTTP 400 error will be returned if the content is detected with a value greater than the threshold set in the `config.yaml`.
|
||||
The details of the response will describe :
|
||||
- The `source` : input text or llm generated text
|
||||
- The `category` : the category of the content that triggered the moderation
|
||||
- The `severity` : the severity from 0 to 10
|
||||
|
||||
**Step 4**: Customizing Azure Content Safety Thresholds
|
||||
|
||||
You can customize the thresholds for each category by setting the `thresholds` in the `config.yaml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
callbacks: ["azure_content_safety"]
|
||||
azure_content_safety_params:
|
||||
endpoint: "<your-azure-content-safety-endpoint>"
|
||||
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
|
||||
thresholds:
|
||||
Hate: 6
|
||||
SelfHarm: 8
|
||||
Sexual: 6
|
||||
Violence: 4
|
||||
```
|
||||
|
||||
:::info
|
||||
`thresholds` are not required by default, but you can tune the values to your needs.
|
||||
Default values is `4` for all categories
|
||||
:::
|
|
@ -64,6 +64,12 @@ router_settings:
|
|||
redis_password: os.environ/REDIS_PASSWORD
|
||||
```
|
||||
|
||||
## 4. Disable 'load_dotenv'
|
||||
|
||||
Set `export LITELLM_MODE="PRODUCTION"`
|
||||
|
||||
This disables the load_dotenv() functionality, which will automatically load your environment credentials from the local `.env`.
|
||||
|
||||
## Extras
|
||||
### Expected Performance in Production
|
||||
|
||||
|
|
|
@ -151,7 +151,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
}'
|
||||
```
|
||||
|
||||
## Advanced - Context Window Fallbacks
|
||||
## Advanced - Context Window Fallbacks (Pre-Call Checks + Fallbacks)
|
||||
|
||||
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||
|
||||
|
@ -287,6 +287,69 @@ print(response)
|
|||
</Tabs>
|
||||
|
||||
|
||||
## Advanced - EU-Region Filtering (Pre-Call Checks)
|
||||
|
||||
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||
|
||||
Set 'region_name' of deployment.
|
||||
|
||||
**Note:** LiteLLM can automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
|
||||
|
||||
**1. Set Config**
|
||||
|
||||
```yaml
|
||||
router_settings:
|
||||
enable_pre_call_checks: true # 1. Enable pre-call checks
|
||||
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
region_name: "eu" # 👈 SET EU-REGION
|
||||
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo-1106
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
- model_name: gemini-pro
|
||||
litellm_params:
|
||||
model: vertex_ai/gemini-pro-1.5
|
||||
vertex_project: adroit-crow-1234
|
||||
vertex_location: us-east1 # 👈 AUTOMATICALLY INFERS 'region_name'
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.with_raw_response.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages = [{"role": "user", "content": "Who was Alexander?"}]
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
print(f"response.headers.get('x-litellm-model-api-base')")
|
||||
```
|
||||
|
||||
## Advanced - Custom Timeouts, Stream Timeouts - Per Model
|
||||
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
|
||||
```yaml
|
||||
|
|
|
@ -110,7 +110,7 @@ general_settings:
|
|||
admin_jwt_scope: "litellm-proxy-admin"
|
||||
```
|
||||
|
||||
## Advanced - Spend Tracking (User / Team / Org)
|
||||
## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
|
||||
|
||||
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
||||
|
||||
|
@ -123,6 +123,7 @@ general_settings:
|
|||
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
|
||||
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
|
||||
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
|
||||
end_user_id_jwt_field: "customer_id" # 👈 CAN BE ANY FIELD
|
||||
```
|
||||
|
||||
Expected JWT:
|
||||
|
@ -131,7 +132,7 @@ Expected JWT:
|
|||
{
|
||||
"client_id": "my-unique-team",
|
||||
"sub": "my-unique-user",
|
||||
"org_id": "my-unique-org"
|
||||
"org_id": "my-unique-org",
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
@ -365,6 +365,188 @@ curl --location 'http://0.0.0.0:4000/moderations' \
|
|||
|
||||
## Advanced
|
||||
|
||||
### (BETA) Batch Completions - pass multiple models
|
||||
|
||||
Use this when you want to send 1 request to N Models
|
||||
|
||||
#### Expected Request Format
|
||||
|
||||
Pass model as a string of comma separated value of models. Example `"model"="llama3,gpt-3.5-turbo"`
|
||||
|
||||
This same request will be sent to the following model groups on the [litellm proxy config.yaml](https://docs.litellm.ai/docs/proxy/configs)
|
||||
- `model_name="llama3"`
|
||||
- `model_name="gpt-3.5-turbo"`
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="openai-py" label="OpenAI Python SDK">
|
||||
|
||||
|
||||
```python
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo,llama3",
|
||||
messages=[
|
||||
{"role": "user", "content": "this is a test request, write a short poem"}
|
||||
],
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
|
||||
|
||||
#### Expected Response Format
|
||||
|
||||
Get a list of responses when `model` is passed as a list
|
||||
|
||||
```python
|
||||
[
|
||||
ChatCompletion(
|
||||
id='chatcmpl-9NoYhS2G0fswot0b6QpoQgmRQMaIf',
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason='stop',
|
||||
index=0,
|
||||
logprobs=None,
|
||||
message=ChatCompletionMessage(
|
||||
content='In the depths of my soul, a spark ignites\nA light that shines so pure and bright\nIt dances and leaps, refusing to die\nA flame of hope that reaches the sky\n\nIt warms my heart and fills me with bliss\nA reminder that in darkness, there is light to kiss\nSo I hold onto this fire, this guiding light\nAnd let it lead me through the darkest night.',
|
||||
role='assistant',
|
||||
function_call=None,
|
||||
tool_calls=None
|
||||
)
|
||||
)
|
||||
],
|
||||
created=1715462919,
|
||||
model='gpt-3.5-turbo-0125',
|
||||
object='chat.completion',
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=83,
|
||||
prompt_tokens=17,
|
||||
total_tokens=100
|
||||
)
|
||||
),
|
||||
ChatCompletion(
|
||||
id='chatcmpl-4ac3e982-da4e-486d-bddb-ed1d5cb9c03c',
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason='stop',
|
||||
index=0,
|
||||
logprobs=None,
|
||||
message=ChatCompletionMessage(
|
||||
content="A test request, and I'm delighted!\nHere's a short poem, just for you:\n\nMoonbeams dance upon the sea,\nA path of light, for you to see.\nThe stars up high, a twinkling show,\nA night of wonder, for all to know.\n\nThe world is quiet, save the night,\nA peaceful hush, a gentle light.\nThe world is full, of beauty rare,\nA treasure trove, beyond compare.\n\nI hope you enjoyed this little test,\nA poem born, of whimsy and jest.\nLet me know, if there's anything else!",
|
||||
role='assistant',
|
||||
function_call=None,
|
||||
tool_calls=None
|
||||
)
|
||||
)
|
||||
],
|
||||
created=1715462919,
|
||||
model='groq/llama3-8b-8192',
|
||||
object='chat.completion',
|
||||
system_fingerprint='fp_a2c8d063cb',
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=120,
|
||||
prompt_tokens=20,
|
||||
total_tokens=140
|
||||
)
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
|
||||
|
||||
|
||||
```shell
|
||||
curl --location 'http://localhost:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "llama3,gpt-3.5-turbo",
|
||||
"max_tokens": 10,
|
||||
"user": "litellm2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "is litellm getting better"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
#### Expected Response Format
|
||||
|
||||
Get a list of responses when `model` is passed as a list
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "chatcmpl-3dbd5dd8-7c82-4ca3-bf1f-7c26f497cf2b",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "The Elder Scrolls IV: Oblivion!\n\nReleased",
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1715459876,
|
||||
"model": "groq/llama3-8b-8192",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "fp_179b0f92c9",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 12,
|
||||
"total_tokens": 22
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl-9NnldUfFLmVquFHSX4yAtjCw8PGei",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "TES4 could refer to The Elder Scrolls IV:",
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1715459877,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": null,
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 9,
|
||||
"total_tokens": 19
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
### Pass User LLM API Keys, Fallbacks
|
||||
Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests
|
||||
|
||||
|
|
|
@ -653,7 +653,9 @@ from litellm import Router
|
|||
model_list = [{...}]
|
||||
|
||||
router = Router(model_list=model_list,
|
||||
allowed_fails=1) # cooldown model if it fails > 1 call in a minute.
|
||||
allowed_fails=1, # cooldown model if it fails > 1 call in a minute.
|
||||
cooldown_time=100 # cooldown the deployment for 100 seconds if it num_fails > allowed_fails
|
||||
)
|
||||
|
||||
user_message = "Hello, whats the weather in San Francisco??"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
|
@ -770,6 +772,8 @@ If the error is a context window exceeded error, fall back to a larger model gro
|
|||
|
||||
Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
|
||||
|
||||
You can also set 'default_fallbacks', in case a specific model group is misconfigured / bad.
|
||||
|
||||
```python
|
||||
from litellm import Router
|
||||
|
||||
|
@ -830,6 +834,7 @@ model_list = [
|
|||
|
||||
router = Router(model_list=model_list,
|
||||
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
|
||||
default_fallbacks=["gpt-3.5-turbo-16k"],
|
||||
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
|
||||
set_verbose=True)
|
||||
|
||||
|
@ -879,13 +884,11 @@ router = Router(model_list: Optional[list] = None,
|
|||
cache_responses=True)
|
||||
```
|
||||
|
||||
## Pre-Call Checks (Context Window)
|
||||
## Pre-Call Checks (Context Window, EU-Regions)
|
||||
|
||||
Enable pre-call checks to filter out:
|
||||
1. deployments with context window limit < messages for a call.
|
||||
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[
|
||||
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
|
||||
])`)
|
||||
2. deployments outside of eu-region
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
@ -900,10 +903,14 @@ router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set t
|
|||
|
||||
**2. Set Model List**
|
||||
|
||||
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
|
||||
For context window checks on azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="same-group" label="Same Group">
|
||||
For 'eu-region' filtering, Set 'region_name' of deployment.
|
||||
|
||||
**Note:** We automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
|
||||
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/d33e49411d6503cb634f9652873160cd534dec96/litellm/router.py#L2958)
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
|
@ -914,10 +921,9 @@ model_list = [
|
|||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"model_info": {
|
||||
"region_name": "eu" # 👈 SET 'EU' REGION NAME
|
||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
|
@ -926,54 +932,26 @@ model_list = [
|
|||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-pro",
|
||||
"litellm_params: {
|
||||
"model": "vertex_ai/gemini-pro-1.5",
|
||||
"vertex_project": "adroit-crow-1234",
|
||||
"vertex_location": "us-east1" # 👈 AUTOMATICALLY INFERS 'region_name'
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-small", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"model_info": {
|
||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-large", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "claude-opus",
|
||||
"litellm_params": { call
|
||||
"model": "claude-3-opus-20240229",
|
||||
"api_key": os.getenv("ANTHROPIC_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="context-window-check" label="Context Window Check">
|
||||
|
||||
```python
|
||||
"""
|
||||
- Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
|
||||
|
@ -983,7 +961,6 @@ router = Router(model_list=model_list, enable_pre_call_checks=True, context_wind
|
|||
from litellm import Router
|
||||
import os
|
||||
|
||||
try:
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
|
@ -992,6 +969,7 @@ model_list = [
|
|||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"base_model": "azure/gpt-35-turbo",
|
||||
},
|
||||
"model_info": {
|
||||
"base_model": "azure/gpt-35-turbo",
|
||||
|
@ -1021,6 +999,59 @@ response = router.completion(
|
|||
print(f"response: {response}")
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="eu-region-check" label="EU Region Check">
|
||||
|
||||
```python
|
||||
"""
|
||||
- Give 2 gpt-3.5-turbo deployments, in eu + non-eu regions
|
||||
- Make a call
|
||||
- Assert it picks the eu-region model
|
||||
"""
|
||||
|
||||
from litellm import Router
|
||||
import os
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"region_name": "eu"
|
||||
},
|
||||
"model_info": {
|
||||
"id": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
"model_info": {
|
||||
"id": "2"
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||
|
||||
response = router.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Who was Alexander?"}],
|
||||
)
|
||||
|
||||
print(f"response: {response}")
|
||||
|
||||
print(f"response id: {response._hidden_params['model_id']}")
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
:::info
|
||||
|
@ -1283,10 +1314,11 @@ def __init__(
|
|||
num_retries: int = 0,
|
||||
timeout: Optional[float] = None,
|
||||
default_litellm_params={}, # default params for Router.chat.completion.create
|
||||
fallbacks: List = [],
|
||||
fallbacks: Optional[List] = None,
|
||||
default_fallbacks: Optional[List] = None
|
||||
allowed_fails: Optional[int] = None, # Number of times a deployment can failbefore being added to cooldown
|
||||
cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure
|
||||
context_window_fallbacks: List = [],
|
||||
context_window_fallbacks: Optional[List] = None,
|
||||
model_group_alias: Optional[dict] = {},
|
||||
retry_after: int = 0, # (min) time to wait before retrying a failed request
|
||||
routing_strategy: Literal[
|
||||
|
|
BIN
docs/my-website/img/lago.jpeg
Normal file
BIN
docs/my-website/img/lago.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 344 KiB |
BIN
docs/my-website/img/lago_2.png
Normal file
BIN
docs/my-website/img/lago_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 176 KiB |
|
@ -39,7 +39,9 @@ const sidebars = {
|
|||
"proxy/demo",
|
||||
"proxy/configs",
|
||||
"proxy/reliability",
|
||||
"proxy/cost_tracking",
|
||||
"proxy/users",
|
||||
"proxy/billing",
|
||||
"proxy/user_keys",
|
||||
"proxy/enterprise",
|
||||
"proxy/virtual_keys",
|
||||
|
@ -52,7 +54,6 @@ const sidebars = {
|
|||
"proxy/team_based_routing",
|
||||
"proxy/customer_routing",
|
||||
"proxy/ui",
|
||||
"proxy/cost_tracking",
|
||||
"proxy/token_auth",
|
||||
{
|
||||
type: "category",
|
||||
|
@ -134,6 +135,7 @@ const sidebars = {
|
|||
"providers/huggingface",
|
||||
"providers/watsonx",
|
||||
"providers/predibase",
|
||||
"providers/triton-inference-server",
|
||||
"providers/ollama",
|
||||
"providers/perplexity",
|
||||
"providers/groq",
|
||||
|
@ -174,6 +176,7 @@ const sidebars = {
|
|||
"observability/custom_callback",
|
||||
"observability/langfuse_integration",
|
||||
"observability/sentry",
|
||||
"observability/lago",
|
||||
"observability/openmeter",
|
||||
"observability/promptlayer_integration",
|
||||
"observability/wandb_integration",
|
||||
|
@ -188,7 +191,7 @@ const sidebars = {
|
|||
`observability/telemetry`,
|
||||
],
|
||||
},
|
||||
"caching/redis_cache",
|
||||
"caching/all_caches",
|
||||
{
|
||||
type: "category",
|
||||
label: "Tutorials",
|
||||
|
|
|
@ -10,7 +10,6 @@ from litellm.caching import DualCache
|
|||
|
||||
from typing import Literal, Union
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
@ -19,8 +18,6 @@ import traceback
|
|||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Enterprise Proxy Util Endpoints
|
||||
from litellm._logging import verbose_logger
|
||||
import collections
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
||||
|
@ -18,26 +19,33 @@ async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
|||
return response
|
||||
|
||||
|
||||
async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
||||
response = await prisma_client.db.query_raw(
|
||||
"""
|
||||
async def ui_get_spend_by_tags(start_date: str, end_date: str, prisma_client):
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
||||
DATE(s."startTime") AS spend_date,
|
||||
COUNT(*) AS log_count,
|
||||
SUM(spend) AS total_spend
|
||||
FROM "LiteLLM_SpendLogs" s
|
||||
WHERE s."startTime" >= current_date - interval '30 days'
|
||||
WHERE
|
||||
DATE(s."startTime") >= $1::date
|
||||
AND DATE(s."startTime") <= $2::date
|
||||
GROUP BY individual_request_tag, spend_date
|
||||
ORDER BY spend_date;
|
||||
ORDER BY spend_date
|
||||
LIMIT 100;
|
||||
"""
|
||||
response = await prisma_client.db.query_raw(
|
||||
sql_query,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
|
||||
# print("tags - spend")
|
||||
# print(response)
|
||||
# Bar Chart 1 - Spend per tag - Top 10 tags by spend
|
||||
total_spend_per_tag = collections.defaultdict(float)
|
||||
total_requests_per_tag = collections.defaultdict(int)
|
||||
total_spend_per_tag: collections.defaultdict = collections.defaultdict(float)
|
||||
total_requests_per_tag: collections.defaultdict = collections.defaultdict(int)
|
||||
for row in response:
|
||||
tag_name = row["individual_request_tag"]
|
||||
tag_spend = row["total_spend"]
|
||||
|
@ -49,15 +57,18 @@ async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=Non
|
|||
# convert to ui format
|
||||
ui_tags = []
|
||||
for tag in sorted_tags:
|
||||
current_spend = tag[1]
|
||||
if current_spend is not None and isinstance(current_spend, float):
|
||||
current_spend = round(current_spend, 4)
|
||||
ui_tags.append(
|
||||
{
|
||||
"name": tag[0],
|
||||
"value": tag[1],
|
||||
"spend": current_spend,
|
||||
"log_count": total_requests_per_tag[tag[0]],
|
||||
}
|
||||
)
|
||||
|
||||
return {"top_10_tags": ui_tags}
|
||||
return {"spend_per_tag": ui_tags}
|
||||
|
||||
|
||||
async def view_spend_logs_from_clickhouse(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
### Hide pydantic namespace conflict warnings globally ###
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
||||
### INIT VARIABLES ###
|
||||
import threading, requests, os
|
||||
|
@ -14,7 +15,9 @@ from litellm.proxy._types import (
|
|||
import httpx
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
|
||||
if litellm_mode == "DEV":
|
||||
dotenv.load_dotenv()
|
||||
#############################################
|
||||
if set_verbose == True:
|
||||
_turn_on_debug()
|
||||
|
@ -24,8 +27,8 @@ input_callback: List[Union[str, Callable]] = []
|
|||
success_callback: List[Union[str, Callable]] = []
|
||||
failure_callback: List[Union[str, Callable]] = []
|
||||
service_callback: List[Union[str, Callable]] = []
|
||||
callbacks: List[Callable] = []
|
||||
_custom_logger_compatible_callbacks: list = ["openmeter"]
|
||||
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter"]
|
||||
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
|
||||
_langfuse_default_tags: Optional[
|
||||
List[
|
||||
Literal[
|
||||
|
@ -70,6 +73,7 @@ azure_key: Optional[str] = None
|
|||
anthropic_key: Optional[str] = None
|
||||
replicate_key: Optional[str] = None
|
||||
cohere_key: Optional[str] = None
|
||||
clarifai_key: Optional[str] = None
|
||||
maritalk_key: Optional[str] = None
|
||||
ai21_key: Optional[str] = None
|
||||
ollama_key: Optional[str] = None
|
||||
|
@ -100,6 +104,9 @@ blocked_user_list: Optional[Union[str, List]] = None
|
|||
banned_keywords_list: Optional[Union[str, List]] = None
|
||||
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||
##################
|
||||
### PREVIEW FEATURES ###
|
||||
enable_preview_features: bool = False
|
||||
##################
|
||||
logging: bool = True
|
||||
caching: bool = (
|
||||
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
|
@ -214,6 +221,7 @@ max_end_user_budget: Optional[float] = None
|
|||
#### RELIABILITY ####
|
||||
request_timeout: Optional[float] = 6000
|
||||
num_retries: Optional[int] = None # per model endpoint
|
||||
default_fallbacks: Optional[List] = None
|
||||
fallbacks: Optional[List] = None
|
||||
context_window_fallbacks: Optional[List] = None
|
||||
allowed_fails: int = 0
|
||||
|
@ -400,6 +408,73 @@ replicate_models: List = [
|
|||
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
|
||||
]
|
||||
|
||||
clarifai_models: List = [
|
||||
"clarifai/meta.Llama-3.Llama-3-8B-Instruct",
|
||||
"clarifai/gcp.generate.gemma-1_1-7b-it",
|
||||
"clarifai/mistralai.completion.mixtral-8x22B",
|
||||
"clarifai/cohere.generate.command-r-plus",
|
||||
"clarifai/databricks.drbx.dbrx-instruct",
|
||||
"clarifai/mistralai.completion.mistral-large",
|
||||
"clarifai/mistralai.completion.mistral-medium",
|
||||
"clarifai/mistralai.completion.mistral-small",
|
||||
"clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
|
||||
"clarifai/gcp.generate.gemma-2b-it",
|
||||
"clarifai/gcp.generate.gemma-7b-it",
|
||||
"clarifai/deci.decilm.deciLM-7B-instruct",
|
||||
"clarifai/mistralai.completion.mistral-7B-Instruct",
|
||||
"clarifai/gcp.generate.gemini-pro",
|
||||
"clarifai/anthropic.completion.claude-v1",
|
||||
"clarifai/anthropic.completion.claude-instant-1_2",
|
||||
"clarifai/anthropic.completion.claude-instant",
|
||||
"clarifai/anthropic.completion.claude-v2",
|
||||
"clarifai/anthropic.completion.claude-2_1",
|
||||
"clarifai/meta.Llama-2.codeLlama-70b-Python",
|
||||
"clarifai/meta.Llama-2.codeLlama-70b-Instruct",
|
||||
"clarifai/openai.completion.gpt-3_5-turbo-instruct",
|
||||
"clarifai/meta.Llama-2.llama2-7b-chat",
|
||||
"clarifai/meta.Llama-2.llama2-13b-chat",
|
||||
"clarifai/meta.Llama-2.llama2-70b-chat",
|
||||
"clarifai/openai.chat-completion.gpt-4-turbo",
|
||||
"clarifai/microsoft.text-generation.phi-2",
|
||||
"clarifai/meta.Llama-2.llama2-7b-chat-vllm",
|
||||
"clarifai/upstage.solar.solar-10_7b-instruct",
|
||||
"clarifai/openchat.openchat.openchat-3_5-1210",
|
||||
"clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
|
||||
"clarifai/gcp.generate.text-bison",
|
||||
"clarifai/meta.Llama-2.llamaGuard-7b",
|
||||
"clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
|
||||
"clarifai/openai.chat-completion.GPT-4",
|
||||
"clarifai/openai.chat-completion.GPT-3_5-turbo",
|
||||
"clarifai/ai21.complete.Jurassic2-Grande",
|
||||
"clarifai/ai21.complete.Jurassic2-Grande-Instruct",
|
||||
"clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
|
||||
"clarifai/ai21.complete.Jurassic2-Jumbo",
|
||||
"clarifai/ai21.complete.Jurassic2-Large",
|
||||
"clarifai/cohere.generate.cohere-generate-command",
|
||||
"clarifai/wizardlm.generate.wizardCoder-Python-34B",
|
||||
"clarifai/wizardlm.generate.wizardLM-70B",
|
||||
"clarifai/tiiuae.falcon.falcon-40b-instruct",
|
||||
"clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
|
||||
"clarifai/gcp.generate.code-gecko",
|
||||
"clarifai/gcp.generate.code-bison",
|
||||
"clarifai/mistralai.completion.mistral-7B-OpenOrca",
|
||||
"clarifai/mistralai.completion.openHermes-2-mistral-7B",
|
||||
"clarifai/wizardlm.generate.wizardLM-13B",
|
||||
"clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
|
||||
"clarifai/wizardlm.generate.wizardCoder-15B",
|
||||
"clarifai/microsoft.text-generation.phi-1_5",
|
||||
"clarifai/databricks.Dolly-v2.dolly-v2-12b",
|
||||
"clarifai/bigcode.code.StarCoder",
|
||||
"clarifai/salesforce.xgen.xgen-7b-8k-instruct",
|
||||
"clarifai/mosaicml.mpt.mpt-7b-instruct",
|
||||
"clarifai/anthropic.completion.claude-3-opus",
|
||||
"clarifai/anthropic.completion.claude-3-sonnet",
|
||||
"clarifai/gcp.generate.gemini-1_5-pro",
|
||||
"clarifai/gcp.generate.imagen-2",
|
||||
"clarifai/salesforce.blip.general-english-image-caption-blip-2",
|
||||
]
|
||||
|
||||
|
||||
huggingface_models: List = [
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
|
@ -505,6 +580,7 @@ provider_list: List = [
|
|||
"text-completion-openai",
|
||||
"cohere",
|
||||
"cohere_chat",
|
||||
"clarifai",
|
||||
"anthropic",
|
||||
"replicate",
|
||||
"huggingface",
|
||||
|
@ -537,6 +613,7 @@ provider_list: List = [
|
|||
"xinference",
|
||||
"fireworks_ai",
|
||||
"watsonx",
|
||||
"triton",
|
||||
"predibase",
|
||||
"custom", # custom apis
|
||||
]
|
||||
|
@ -654,6 +731,7 @@ from .llms.predibase import PredibaseConfig
|
|||
from .llms.anthropic_text import AnthropicTextConfig
|
||||
from .llms.replicate import ReplicateConfig
|
||||
from .llms.cohere import CohereConfig
|
||||
from .llms.clarifai import ClarifaiConfig
|
||||
from .llms.ai21 import AI21Config
|
||||
from .llms.together_ai import TogetherAIConfig
|
||||
from .llms.cloudflare import CloudflareConfig
|
||||
|
@ -668,6 +746,7 @@ from .llms.sagemaker import SagemakerConfig
|
|||
from .llms.ollama import OllamaConfig
|
||||
from .llms.ollama_chat import OllamaChatConfig
|
||||
from .llms.maritalk import MaritTalkConfig
|
||||
from .llms.bedrock_httpx import AmazonCohereChatConfig
|
||||
from .llms.bedrock import (
|
||||
AmazonTitanConfig,
|
||||
AmazonAI21Config,
|
||||
|
@ -679,7 +758,7 @@ from .llms.bedrock import (
|
|||
AmazonMistralConfig,
|
||||
AmazonBedrockGlobalConfig,
|
||||
)
|
||||
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
|
||||
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, MistralConfig
|
||||
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
|
||||
from .llms.watsonx import IBMWatsonXAIConfig
|
||||
from .main import * # type: ignore
|
||||
|
|
|
@ -373,11 +373,12 @@ class RedisCache(BaseCache):
|
|||
print_verbose(
|
||||
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
||||
)
|
||||
json_cache_value = json.dumps(cache_value)
|
||||
# Set the value with a TTL if it's provided.
|
||||
if ttl is not None:
|
||||
pipe.setex(cache_key, ttl, json.dumps(cache_value))
|
||||
pipe.setex(cache_key, ttl, json_cache_value)
|
||||
else:
|
||||
pipe.set(cache_key, json.dumps(cache_value))
|
||||
pipe.set(cache_key, json_cache_value)
|
||||
# Execute the pipeline and return the results.
|
||||
results = await pipe.execute()
|
||||
|
||||
|
@ -810,9 +811,7 @@ class RedisSemanticCache(BaseCache):
|
|||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
prompt = "".join(message["content"] for message in messages)
|
||||
|
||||
# create an embedding for prompt
|
||||
embedding_response = litellm.embedding(
|
||||
|
@ -847,9 +846,7 @@ class RedisSemanticCache(BaseCache):
|
|||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
prompt = "".join(message["content"] for message in messages)
|
||||
|
||||
# convert to embedding
|
||||
embedding_response = litellm.embedding(
|
||||
|
@ -909,9 +906,7 @@ class RedisSemanticCache(BaseCache):
|
|||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
prompt = "".join(message["content"] for message in messages)
|
||||
# create an embedding for prompt
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
|
@ -964,9 +959,7 @@ class RedisSemanticCache(BaseCache):
|
|||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
prompt = "".join(message["content"] for message in messages)
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
|
@ -1448,7 +1441,7 @@ class DualCache(BaseCache):
|
|||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
@ -1491,13 +1484,14 @@ class Cache:
|
|||
redis_semantic_cache_use_async=False,
|
||||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
redis_flush_size=None,
|
||||
disk_cache_dir=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local".
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local".
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
|
@ -1543,6 +1537,8 @@ class Cache:
|
|||
s3_path=s3_path,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == "disk":
|
||||
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
|
@ -1914,8 +1910,86 @@ class Cache:
|
|||
await self.cache.disconnect()
|
||||
|
||||
|
||||
class DiskCache(BaseCache):
|
||||
def __init__(self, disk_cache_dir: Optional[str] = None):
|
||||
import diskcache as dc
|
||||
|
||||
# if users don't provider one, use the default litellm cache
|
||||
if disk_cache_dir is None:
|
||||
self.disk_cache = dc.Cache(".litellm_cache")
|
||||
else:
|
||||
self.disk_cache = dc.Cache(disk_cache_dir)
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
print_verbose("DiskCache: set_cache")
|
||||
if "ttl" in kwargs:
|
||||
self.disk_cache.set(key, value, expire=kwargs["ttl"])
|
||||
else:
|
||||
self.disk_cache.set(key, value)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if ttl is not None:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
original_cached_response = self.disk_cache.get(key)
|
||||
if original_cached_response:
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
def flush_cache(self):
|
||||
self.disk_cache.clear()
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def delete_cache(self, key):
|
||||
self.disk_cache.pop(key)
|
||||
|
||||
|
||||
def enable_cache(
|
||||
type: Optional[Literal["local", "redis", "s3"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
@ -1944,7 +2018,7 @@ def enable_cache(
|
|||
Enable cache with the specified configuration.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis"]]): The type of cache to enable. Defaults to "local".
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
|
||||
host (Optional[str]): The host address of the cache server. Defaults to None.
|
||||
port (Optional[str]): The port number of the cache server. Defaults to None.
|
||||
password (Optional[str]): The password for the cache server. Defaults to None.
|
||||
|
@ -1980,7 +2054,7 @@ def enable_cache(
|
|||
|
||||
|
||||
def update_cache(
|
||||
type: Optional[Literal["local", "redis"]] = "local",
|
||||
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
|
@ -2009,7 +2083,7 @@ def update_cache(
|
|||
Update the cache for LiteLLM.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local".
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
|
||||
host (Optional[str]): The host of the cache. Defaults to None.
|
||||
port (Optional[str]): The port of the cache. Defaults to None.
|
||||
password (Optional[str]): The password for the cache. Defaults to None.
|
||||
|
|
|
@ -9,55 +9,64 @@
|
|||
|
||||
## LiteLLM versions of the OpenAI Exception Types
|
||||
|
||||
from openai import (
|
||||
AuthenticationError,
|
||||
BadRequestError,
|
||||
NotFoundError,
|
||||
RateLimitError,
|
||||
APIStatusError,
|
||||
OpenAIError,
|
||||
APIError,
|
||||
APITimeoutError,
|
||||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError,
|
||||
PermissionDeniedError,
|
||||
)
|
||||
import openai
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AuthenticationError(AuthenticationError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
class AuthenticationError(openai.AuthenticationError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 401
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
# raise when invalid models passed, example gpt-8
|
||||
class NotFoundError(NotFoundError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
class NotFoundError(openai.NotFoundError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 404
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class BadRequestError(BadRequestError): # type: ignore
|
||||
class BadRequestError(openai.BadRequestError): # type: ignore
|
||||
def __init__(
|
||||
self, message, model, llm_provider, response: Optional[httpx.Response] = None
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
response = response or httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
|
@ -69,19 +78,29 @@ class BadRequestError(BadRequestError): # type: ignore
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 422
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class Timeout(APITimeoutError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
class Timeout(openai.APITimeoutError): # type: ignore
|
||||
def __init__(
|
||||
self, message, model, llm_provider, litellm_debug_info: Optional[str] = None
|
||||
):
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
super().__init__(
|
||||
request=request
|
||||
|
@ -90,29 +109,46 @@ class Timeout(APITimeoutError): # type: ignore
|
|||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
||||
# custom function to convert to str
|
||||
def __str__(self):
|
||||
return str(self.message)
|
||||
|
||||
|
||||
class PermissionDeniedError(PermissionDeniedError): # type:ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 403
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class RateLimitError(RateLimitError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
class RateLimitError(openai.RateLimitError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 429
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.modle = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
@ -120,11 +156,19 @@ class RateLimitError(RateLimitError): # type: ignore
|
|||
|
||||
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
|
||||
class ContextWindowExceededError(BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
|
@ -135,11 +179,19 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
|
|||
|
||||
class ContentPolicyViolationError(BadRequestError): # type: ignore
|
||||
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
|
@ -148,51 +200,77 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class ServiceUnavailableError(APIStatusError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
class ServiceUnavailableError(openai.APIStatusError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 503
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
|
||||
class APIError(APIError): # type: ignore
|
||||
class APIError(openai.APIError): # type: ignore
|
||||
def __init__(
|
||||
self, status_code, message, llm_provider, model, request: httpx.Request
|
||||
self,
|
||||
status_code,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: httpx.Request,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(self.message, request=request, body=None) # type: ignore
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIConnectionError(APIConnectionError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, request: httpx.Request):
|
||||
class APIConnectionError(openai.APIConnectionError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: httpx.Request,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.status_code = 500
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(message=self.message, request=request)
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIResponseValidationError(APIResponseValidationError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model):
|
||||
class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
|
||||
def __init__(
|
||||
self, message, llm_provider, model, litellm_debug_info: Optional[str] = None
|
||||
):
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
response = httpx.Response(status_code=500, request=request)
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(response=response, body=None, message=message)
|
||||
|
||||
|
||||
class OpenAIError(OpenAIError): # type: ignore
|
||||
class OpenAIError(openai.OpenAIError): # type: ignore
|
||||
def __init__(self, original_exception):
|
||||
self.status_code = original_exception.http_status
|
||||
super().__init__(
|
||||
|
@ -214,7 +292,7 @@ class BudgetExceededError(Exception):
|
|||
|
||||
|
||||
## DEPRECATED ##
|
||||
class InvalidRequestError(BadRequestError): # type: ignore
|
||||
class InvalidRequestError(openai.BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to aispend.io
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
|
|
|
@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
|
|||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
@ -18,8 +16,6 @@ import traceback
|
|||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
|
|||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union, Optional
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
179
litellm/integrations/lago.py
Normal file
179
litellm/integrations/lago.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
# What is this?
|
||||
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
|
||||
|
||||
import dotenv, os, json
|
||||
import litellm
|
||||
import traceback, httpx
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
import uuid
|
||||
from typing import Optional, Literal
|
||||
|
||||
|
||||
def get_utc_datetime():
|
||||
import datetime as dt
|
||||
from datetime import datetime
|
||||
|
||||
if hasattr(dt, "UTC"):
|
||||
return datetime.now(dt.UTC) # type: ignore
|
||||
else:
|
||||
return datetime.utcnow() # type: ignore
|
||||
|
||||
|
||||
class LagoLogger(CustomLogger):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.validate_environment()
|
||||
self.async_http_handler = AsyncHTTPHandler()
|
||||
self.sync_http_handler = HTTPHandler()
|
||||
|
||||
def validate_environment(self):
|
||||
"""
|
||||
Expects
|
||||
LAGO_API_BASE,
|
||||
LAGO_API_KEY,
|
||||
LAGO_API_EVENT_CODE,
|
||||
|
||||
Optional:
|
||||
LAGO_API_CHARGE_BY
|
||||
|
||||
in the environment
|
||||
"""
|
||||
missing_keys = []
|
||||
if os.getenv("LAGO_API_KEY", None) is None:
|
||||
missing_keys.append("LAGO_API_KEY")
|
||||
|
||||
if os.getenv("LAGO_API_BASE", None) is None:
|
||||
missing_keys.append("LAGO_API_BASE")
|
||||
|
||||
if os.getenv("LAGO_API_EVENT_CODE", None) is None:
|
||||
missing_keys.append("LAGO_API_EVENT_CODE")
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||
|
||||
def _common_logic(self, kwargs: dict, response_obj) -> dict:
|
||||
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
dt = get_utc_datetime().isoformat()
|
||||
cost = kwargs.get("response_cost", None)
|
||||
model = kwargs.get("model")
|
||||
usage = {}
|
||||
|
||||
if (
|
||||
isinstance(response_obj, litellm.ModelResponse)
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
) and hasattr(response_obj, "usage"):
|
||||
usage = {
|
||||
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
|
||||
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
|
||||
"total_tokens": response_obj["usage"].get("total_tokens"),
|
||||
}
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
|
||||
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
|
||||
|
||||
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
|
||||
external_customer_id: Optional[str] = None
|
||||
|
||||
if os.getenv("LAGO_API_CHARGE_BY", None) is not None and isinstance(
|
||||
os.environ["LAGO_API_CHARGE_BY"], str
|
||||
):
|
||||
if os.environ["LAGO_API_CHARGE_BY"] in [
|
||||
"end_user_id",
|
||||
"user_id",
|
||||
"team_id",
|
||||
]:
|
||||
charge_by = os.environ["LAGO_API_CHARGE_BY"] # type: ignore
|
||||
else:
|
||||
raise Exception("invalid LAGO_API_CHARGE_BY set")
|
||||
|
||||
if charge_by == "end_user_id":
|
||||
external_customer_id = end_user_id
|
||||
elif charge_by == "team_id":
|
||||
external_customer_id = team_id
|
||||
elif charge_by == "user_id":
|
||||
external_customer_id = user_id
|
||||
|
||||
if external_customer_id is None:
|
||||
raise Exception("External Customer ID is not set")
|
||||
|
||||
return {
|
||||
"event": {
|
||||
"transaction_id": str(uuid.uuid4()),
|
||||
"external_customer_id": external_customer_id,
|
||||
"code": os.getenv("LAGO_API_EVENT_CODE"),
|
||||
"properties": {"model": model, "response_cost": cost, **usage},
|
||||
}
|
||||
}
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
_url = os.getenv("LAGO_API_BASE")
|
||||
assert _url is not None and isinstance(
|
||||
_url, str
|
||||
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = os.getenv("LAGO_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.sync_http_handler.post(
|
||||
url=_url,
|
||||
data=json.dumps(_data),
|
||||
headers=_headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
if hasattr(response, "text"):
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
raise e
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
_url = os.getenv("LAGO_API_BASE")
|
||||
assert _url is not None and isinstance(
|
||||
_url, str
|
||||
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
|
||||
_url
|
||||
)
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = os.getenv("LAGO_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
response: Optional[httpx.Response] = None
|
||||
try:
|
||||
response = await self.async_http_handler.post(
|
||||
url=_url,
|
||||
data=json.dumps(_data),
|
||||
headers=_headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
if response is not None and hasattr(response, "text"):
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
raise e
|
|
@ -1,8 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import os
|
||||
import copy
|
||||
import traceback
|
||||
from packaging.version import Version
|
||||
|
@ -262,7 +260,23 @@ class LangFuseLogger:
|
|||
|
||||
try:
|
||||
tags = []
|
||||
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
|
||||
try:
|
||||
metadata = copy.deepcopy(
|
||||
metadata
|
||||
) # Avoid modifying the original metadata
|
||||
except:
|
||||
new_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
or isinstance(value, str)
|
||||
or isinstance(value, int)
|
||||
or isinstance(value, float)
|
||||
):
|
||||
new_metadata[key] = copy.deepcopy(value)
|
||||
metadata = new_metadata
|
||||
|
||||
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
|
||||
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
|
@ -307,6 +321,9 @@ class LangFuseLogger:
|
|||
trace_id = clean_metadata.pop("trace_id", None)
|
||||
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
|
||||
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
|
||||
debug = clean_metadata.pop("debug_langfuse", None)
|
||||
mask_input = clean_metadata.pop("mask_input", False)
|
||||
mask_output = clean_metadata.pop("mask_output", False)
|
||||
|
||||
if trace_name is None and existing_trace_id is None:
|
||||
# just log `litellm-{call_type}` as the trace name
|
||||
|
@ -334,18 +351,19 @@ class LangFuseLogger:
|
|||
|
||||
# Special keys that are found in the function arguments and not the metadata
|
||||
if "input" in update_trace_keys:
|
||||
trace_params["input"] = input
|
||||
trace_params["input"] = input if not mask_input else "redacted-by-litellm"
|
||||
if "output" in update_trace_keys:
|
||||
trace_params["output"] = output
|
||||
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
|
||||
else: # don't overwrite an existing trace
|
||||
trace_params = {
|
||||
"id": trace_id,
|
||||
"name": trace_name,
|
||||
"session_id": session_id,
|
||||
"input": input,
|
||||
"input": input if not mask_input else "redacted-by-litellm",
|
||||
"version": clean_metadata.pop(
|
||||
"trace_version", clean_metadata.get("version", None)
|
||||
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
|
||||
"user_id": user_id,
|
||||
}
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
|
@ -357,7 +375,14 @@ class LangFuseLogger:
|
|||
if level == "ERROR":
|
||||
trace_params["status_message"] = output
|
||||
else:
|
||||
trace_params["output"] = output
|
||||
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
|
||||
|
||||
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
|
||||
if "metadata" in trace_params:
|
||||
# log the raw_metadata in the trace
|
||||
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
|
||||
else:
|
||||
trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
|
||||
|
||||
cost = kwargs.get("response_cost", None)
|
||||
print_verbose(f"trace: {cost}")
|
||||
|
@ -409,7 +434,6 @@ class LangFuseLogger:
|
|||
"url": url,
|
||||
"headers": clean_headers,
|
||||
}
|
||||
|
||||
trace = self.Langfuse.trace(**trace_params)
|
||||
|
||||
generation_id = None
|
||||
|
@ -441,8 +465,8 @@ class LangFuseLogger:
|
|||
"end_time": end_time,
|
||||
"model": kwargs["model"],
|
||||
"model_parameters": optional_params,
|
||||
"input": input,
|
||||
"output": output,
|
||||
"input": input if not mask_input else "redacted-by-litellm",
|
||||
"output": output if not mask_output else "redacted-by-litellm",
|
||||
"usage": usage,
|
||||
"metadata": clean_metadata,
|
||||
"level": level,
|
||||
|
@ -450,7 +474,29 @@ class LangFuseLogger:
|
|||
}
|
||||
|
||||
if supports_prompt:
|
||||
generation_params["prompt"] = clean_metadata.pop("prompt", None)
|
||||
user_prompt = clean_metadata.pop("prompt", None)
|
||||
if user_prompt is None:
|
||||
pass
|
||||
elif isinstance(user_prompt, dict):
|
||||
from langfuse.model import (
|
||||
TextPromptClient,
|
||||
ChatPromptClient,
|
||||
Prompt_Text,
|
||||
Prompt_Chat,
|
||||
)
|
||||
|
||||
if user_prompt.get("type", "") == "chat":
|
||||
_prompt_chat = Prompt_Chat(**user_prompt)
|
||||
generation_params["prompt"] = ChatPromptClient(
|
||||
prompt=_prompt_chat
|
||||
)
|
||||
elif user_prompt.get("type", "") == "text":
|
||||
_prompt_text = Prompt_Text(**user_prompt)
|
||||
generation_params["prompt"] = TextPromptClient(
|
||||
prompt=_prompt_text
|
||||
)
|
||||
else:
|
||||
generation_params["prompt"] = user_prompt
|
||||
|
||||
if output is not None and isinstance(output, str) and level == "ERROR":
|
||||
generation_params["status_message"] = output
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os # type: ignore
|
||||
import requests # type: ignore
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import asyncio
|
||||
import types
|
||||
|
|
|
@ -2,14 +2,10 @@
|
|||
# On success + failure, log events to lunary.ai
|
||||
from datetime import datetime, timezone
|
||||
import traceback
|
||||
import dotenv
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import packaging
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# convert to {completion: xx, tokens: xx}
|
||||
def parse_usage(usage):
|
||||
|
@ -18,13 +14,33 @@ def parse_usage(usage):
|
|||
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||
}
|
||||
|
||||
def parse_tool_calls(tool_calls):
|
||||
if tool_calls is None:
|
||||
return None
|
||||
|
||||
def clean_tool_call(tool_call):
|
||||
|
||||
serialized = {
|
||||
"type": tool_call.type,
|
||||
"id": tool_call.id,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
}
|
||||
}
|
||||
|
||||
return serialized
|
||||
|
||||
return [clean_tool_call(tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
def parse_messages(input):
|
||||
|
||||
if input is None:
|
||||
return None
|
||||
|
||||
def clean_message(message):
|
||||
# if is strin, return as is
|
||||
# if is string, return as is
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
|
@ -38,9 +54,7 @@ def parse_messages(input):
|
|||
|
||||
# Only add tool_calls and function_call to res if they are set
|
||||
if message.get("tool_calls"):
|
||||
serialized["tool_calls"] = message.get("tool_calls")
|
||||
if message.get("function_call"):
|
||||
serialized["function_call"] = message.get("function_call")
|
||||
serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
|
||||
|
||||
return serialized
|
||||
|
||||
|
@ -62,14 +76,16 @@ class LunaryLogger:
|
|||
version = importlib.metadata.version("lunary")
|
||||
# if version < 0.1.43 then raise ImportError
|
||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
||||
print(
|
||||
print( # noqa
|
||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||
)
|
||||
raise ImportError
|
||||
|
||||
self.lunary_client = lunary
|
||||
except ImportError:
|
||||
print("Lunary not installed. Please install it using 'pip install lunary'")
|
||||
print( # noqa
|
||||
"Lunary not installed. Please install it using 'pip install lunary'"
|
||||
) # noqa
|
||||
raise ImportError
|
||||
|
||||
def log_event(
|
||||
|
@ -93,8 +109,13 @@ class LunaryLogger:
|
|||
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
|
||||
if optional_params:
|
||||
# merge into extra
|
||||
extra = {**extra, **optional_params}
|
||||
|
||||
tags = litellm_params.pop("tags", None) or []
|
||||
|
||||
if extra:
|
||||
|
@ -104,7 +125,7 @@ class LunaryLogger:
|
|||
|
||||
# keep only serializable types
|
||||
for param, value in extra.items():
|
||||
if not isinstance(value, (str, int, bool, float)):
|
||||
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
||||
try:
|
||||
extra[param] = str(value)
|
||||
except:
|
||||
|
@ -140,7 +161,7 @@ class LunaryLogger:
|
|||
metadata=metadata,
|
||||
runtime="litellm",
|
||||
tags=tags,
|
||||
extra=extra,
|
||||
params=extra,
|
||||
)
|
||||
|
||||
self.lunary_client.track_event(
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os, json
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import os
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
# Class for sending Slack Alerts #
|
||||
import dotenv, os
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import litellm, threading
|
||||
from typing import List, Literal, Any, Union, Optional, Dict
|
||||
|
@ -14,7 +12,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
|||
import datetime
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
from datetime import datetime as dt, timedelta
|
||||
from datetime import datetime as dt, timedelta, timezone
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import random
|
||||
|
||||
|
@ -33,7 +31,10 @@ class LiteLLMBase(BaseModel):
|
|||
|
||||
|
||||
class SlackAlertingArgs(LiteLLMBase):
|
||||
daily_report_frequency: int = 12 * 60 * 60 # 12 hours
|
||||
default_daily_report_frequency: int = 12 * 60 * 60 # 12 hours
|
||||
daily_report_frequency: int = int(
|
||||
os.getenv("SLACK_DAILY_REPORT_FREQUENCY", default_daily_report_frequency)
|
||||
)
|
||||
report_check_interval: int = 5 * 60 # 5 minutes
|
||||
|
||||
|
||||
|
@ -78,8 +79,7 @@ class SlackAlerting(CustomLogger):
|
|||
internal_usage_cache: Optional[DualCache] = None,
|
||||
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
|
||||
alerting: Optional[List] = [],
|
||||
alert_types: Optional[
|
||||
List[
|
||||
alert_types: List[
|
||||
Literal[
|
||||
"llm_exceptions",
|
||||
"llm_too_slow",
|
||||
|
@ -88,7 +88,6 @@ class SlackAlerting(CustomLogger):
|
|||
"db_exceptions",
|
||||
"daily_reports",
|
||||
]
|
||||
]
|
||||
] = [
|
||||
"llm_exceptions",
|
||||
"llm_too_slow",
|
||||
|
@ -242,6 +241,8 @@ class SlackAlerting(CustomLogger):
|
|||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
if litellm.turn_off_message_logging:
|
||||
messages = "Message not logged. `litellm.turn_off_message_logging=True`."
|
||||
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
|
||||
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
|
||||
if time_difference_float > self.alerting_threshold:
|
||||
|
@ -348,8 +349,9 @@ class SlackAlerting(CustomLogger):
|
|||
|
||||
all_none = True
|
||||
for val in combined_metrics_values:
|
||||
if val is not None:
|
||||
if val is not None and val > 0:
|
||||
all_none = False
|
||||
break
|
||||
|
||||
if all_none:
|
||||
return False
|
||||
|
@ -367,12 +369,15 @@ class SlackAlerting(CustomLogger):
|
|||
for value in failed_request_values
|
||||
]
|
||||
|
||||
## Get the indices of top 5 keys with the highest numerical values (ignoring None values)
|
||||
## Get the indices of top 5 keys with the highest numerical values (ignoring None and 0 values)
|
||||
top_5_failed = sorted(
|
||||
range(len(replaced_failed_values)),
|
||||
key=lambda i: replaced_failed_values[i],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
top_5_failed = [
|
||||
index for index in top_5_failed if replaced_failed_values[index] > 0
|
||||
]
|
||||
|
||||
# find top 5 slowest
|
||||
# Replace None values with a placeholder value (-1 in this case)
|
||||
|
@ -382,17 +387,22 @@ class SlackAlerting(CustomLogger):
|
|||
for value in latency_values
|
||||
]
|
||||
|
||||
# Get the indices of top 5 values with the highest numerical values (ignoring None values)
|
||||
# Get the indices of top 5 values with the highest numerical values (ignoring None and 0 values)
|
||||
top_5_slowest = sorted(
|
||||
range(len(replaced_slowest_values)),
|
||||
key=lambda i: replaced_slowest_values[i],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
top_5_slowest = [
|
||||
index for index in top_5_slowest if replaced_slowest_values[index] > 0
|
||||
]
|
||||
|
||||
# format alert -> return the litellm model name + api base
|
||||
message = f"\n\nHere are today's key metrics 📈: \n\n"
|
||||
|
||||
message += "\n\n*❗️ Top 5 Deployments with Most Failed Requests:*\n\n"
|
||||
message += "\n\n*❗️ Top Deployments with Most Failed Requests:*\n\n"
|
||||
if not top_5_failed:
|
||||
message += "\tNone\n"
|
||||
for i in range(len(top_5_failed)):
|
||||
key = failed_request_keys[top_5_failed[i]].split(":")[0]
|
||||
_deployment = router.get_model_info(key)
|
||||
|
@ -412,7 +422,9 @@ class SlackAlerting(CustomLogger):
|
|||
value = replaced_failed_values[top_5_failed[i]]
|
||||
message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
|
||||
|
||||
message += "\n\n*😅 Top 5 Slowest Deployments:*\n\n"
|
||||
message += "\n\n*😅 Top Slowest Deployments:*\n\n"
|
||||
if not top_5_slowest:
|
||||
message += "\tNone\n"
|
||||
for i in range(len(top_5_slowest)):
|
||||
key = latency_keys[top_5_slowest[i]].split(":")[0]
|
||||
_deployment = router.get_model_info(key)
|
||||
|
@ -464,6 +476,11 @@ class SlackAlerting(CustomLogger):
|
|||
messages = messages[:100]
|
||||
except:
|
||||
messages = ""
|
||||
|
||||
if litellm.turn_off_message_logging:
|
||||
messages = (
|
||||
"Message not logged. `litellm.turn_off_message_logging=True`."
|
||||
)
|
||||
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
|
||||
else:
|
||||
request_info = ""
|
||||
|
@ -814,14 +831,6 @@ Model Info:
|
|||
updated_at=litellm.utils.get_utc_datetime(),
|
||||
)
|
||||
)
|
||||
if "llm_exceptions" in self.alert_types:
|
||||
original_exception = kwargs.get("exception", None)
|
||||
|
||||
await self.send_alert(
|
||||
message="LLM API Failure - " + str(original_exception),
|
||||
level="High",
|
||||
alert_type="llm_exceptions",
|
||||
)
|
||||
|
||||
async def _run_scheduler_helper(self, llm_router) -> bool:
|
||||
"""
|
||||
|
@ -844,15 +853,22 @@ Model Info:
|
|||
value=_current_time,
|
||||
)
|
||||
else:
|
||||
# check if current time - interval >= time last sent
|
||||
delta = current_time - timedelta(
|
||||
seconds=self.alerting_args.daily_report_frequency
|
||||
)
|
||||
|
||||
# Check if current time - interval >= time last sent
|
||||
delta_naive = timedelta(seconds=self.alerting_args.daily_report_frequency)
|
||||
if isinstance(report_sent, str):
|
||||
report_sent = dt.fromisoformat(report_sent)
|
||||
|
||||
if delta >= report_sent:
|
||||
# Ensure report_sent is an aware datetime object
|
||||
if report_sent.tzinfo is None:
|
||||
report_sent = report_sent.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Calculate delta as an aware datetime object with the same timezone as report_sent
|
||||
delta = report_sent - delta_naive
|
||||
|
||||
current_time_utc = current_time.astimezone(timezone.utc)
|
||||
delta_utc = delta.astimezone(timezone.utc)
|
||||
|
||||
if current_time_utc >= delta_utc:
|
||||
# Sneak in the reporting logic here
|
||||
await self.send_daily_reports(router=llm_router)
|
||||
# Also, don't forget to update the report_sent time after sending the report!
|
||||
|
@ -885,3 +901,99 @@ Model Info:
|
|||
) # shuffle to prevent collisions
|
||||
await asyncio.sleep(interval)
|
||||
return
|
||||
|
||||
async def send_weekly_spend_report(self):
|
||||
""" """
|
||||
try:
|
||||
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
|
||||
|
||||
todays_date = datetime.datetime.now().date()
|
||||
week_before = todays_date - datetime.timedelta(days=7)
|
||||
|
||||
weekly_spend_per_team, weekly_spend_per_tag = (
|
||||
await _get_spend_report_for_time_range(
|
||||
start_date=week_before.strftime("%Y-%m-%d"),
|
||||
end_date=todays_date.strftime("%Y-%m-%d"),
|
||||
)
|
||||
)
|
||||
|
||||
_weekly_spend_message = f"*💸 Weekly Spend Report for `{week_before.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` *\n"
|
||||
|
||||
if weekly_spend_per_team is not None:
|
||||
_weekly_spend_message += "\n*Team Spend Report:*\n"
|
||||
for spend in weekly_spend_per_team:
|
||||
_team_spend = spend["total_spend"]
|
||||
_team_spend = float(_team_spend)
|
||||
# round to 4 decimal places
|
||||
_team_spend = round(_team_spend, 4)
|
||||
_weekly_spend_message += (
|
||||
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
|
||||
)
|
||||
|
||||
if weekly_spend_per_tag is not None:
|
||||
_weekly_spend_message += "\n*Tag Spend Report:*\n"
|
||||
for spend in weekly_spend_per_tag:
|
||||
_tag_spend = spend["total_spend"]
|
||||
_tag_spend = float(_tag_spend)
|
||||
# round to 4 decimal places
|
||||
_tag_spend = round(_tag_spend, 4)
|
||||
_weekly_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
|
||||
|
||||
await self.send_alert(
|
||||
message=_weekly_spend_message,
|
||||
level="Low",
|
||||
alert_type="daily_reports",
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||
|
||||
async def send_monthly_spend_report(self):
|
||||
""" """
|
||||
try:
|
||||
from calendar import monthrange
|
||||
|
||||
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
|
||||
|
||||
todays_date = datetime.datetime.now().date()
|
||||
first_day_of_month = todays_date.replace(day=1)
|
||||
_, last_day_of_month = monthrange(todays_date.year, todays_date.month)
|
||||
last_day_of_month = first_day_of_month + datetime.timedelta(
|
||||
days=last_day_of_month - 1
|
||||
)
|
||||
|
||||
monthly_spend_per_team, monthly_spend_per_tag = (
|
||||
await _get_spend_report_for_time_range(
|
||||
start_date=first_day_of_month.strftime("%Y-%m-%d"),
|
||||
end_date=last_day_of_month.strftime("%Y-%m-%d"),
|
||||
)
|
||||
)
|
||||
|
||||
_spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
|
||||
|
||||
if monthly_spend_per_team is not None:
|
||||
_spend_message += "\n*Team Spend Report:*\n"
|
||||
for spend in monthly_spend_per_team:
|
||||
_team_spend = spend["total_spend"]
|
||||
_team_spend = float(_team_spend)
|
||||
# round to 4 decimal places
|
||||
_team_spend = round(_team_spend, 4)
|
||||
_spend_message += (
|
||||
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
|
||||
)
|
||||
|
||||
if monthly_spend_per_tag is not None:
|
||||
_spend_message += "\n*Tag Spend Report:*\n"
|
||||
for spend in monthly_spend_per_tag:
|
||||
_tag_spend = spend["total_spend"]
|
||||
_tag_spend = float(_tag_spend)
|
||||
# round to 4 decimal places
|
||||
_tag_spend = round(_tag_spend, 4)
|
||||
_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
|
||||
|
||||
await self.send_alert(
|
||||
message=_spend_message,
|
||||
level="Low",
|
||||
alert_type="daily_reports",
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm
|
||||
|
|
|
@ -21,11 +21,11 @@ try:
|
|||
# contains a (known) object attribute
|
||||
object: Literal["chat.completion", "edit", "text_completion"]
|
||||
|
||||
def __getitem__(self, key: K) -> V:
|
||||
... # pragma: no cover
|
||||
def __getitem__(self, key: K) -> V: ... # noqa
|
||||
|
||||
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
|
||||
... # pragma: no cover
|
||||
def get( # noqa
|
||||
self, key: K, default: Optional[V] = None
|
||||
) -> Optional[V]: ... # pragma: no cover
|
||||
|
||||
class OpenAIRequestResponseResolver:
|
||||
def __call__(
|
||||
|
@ -173,12 +173,11 @@ except:
|
|||
|
||||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import dotenv, os
|
||||
import os
|
||||
import requests
|
||||
import requests
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List
|
||||
from typing import Callable, Optional, List, Union
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def process_streaming_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> CustomStreamWrapper:
|
||||
"""
|
||||
Return stream object for tool-calling + streaming
|
||||
"""
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise AnthropicError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
for content in completion_response["content"]:
|
||||
if content["type"] == "text":
|
||||
text_content += content["text"]
|
||||
## TOOL CALLING
|
||||
elif content["type"] == "tool_use":
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": content["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": content["name"],
|
||||
"arguments": json.dumps(content["input"]),
|
||||
},
|
||||
}
|
||||
)
|
||||
if "error" in completion_response:
|
||||
raise AnthropicError(
|
||||
message=str(completion_response["error"]),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=text_content or None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = completion_response[
|
||||
"content"
|
||||
] # allow user to access raw anthropic tool calling response
|
||||
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response["stop_reason"]
|
||||
)
|
||||
|
||||
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||
# return an iterator
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
|
||||
0
|
||||
].finish_reason
|
||||
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||
streaming_choice = litellm.utils.StreamingChoices()
|
||||
streaming_choice.index = model_response.choices[0].index
|
||||
_tool_calls = []
|
||||
print_verbose(
|
||||
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
|
||||
)
|
||||
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
|
||||
if isinstance(model_response.choices[0], litellm.Choices):
|
||||
if getattr(
|
||||
model_response.choices[0].message, "tool_calls", None
|
||||
) is not None and isinstance(
|
||||
model_response.choices[0].message.tool_calls, list
|
||||
):
|
||||
for tool_call in model_response.choices[0].message.tool_calls:
|
||||
_tool_call = {**tool_call.dict(), "index": 0}
|
||||
_tool_calls.append(_tool_call)
|
||||
delta_obj = litellm.utils.Delta(
|
||||
content=getattr(model_response.choices[0].message, "content", None),
|
||||
role=model_response.choices[0].message.role,
|
||||
tool_calls=_tool_calls,
|
||||
)
|
||||
streaming_choice.delta = delta_obj
|
||||
streaming_model_response.choices = [streaming_choice]
|
||||
completion_stream = ModelResponseIterator(
|
||||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
else:
|
||||
raise AnthropicError(
|
||||
status_code=422,
|
||||
message="Unprocessable response object - {}".format(response.text),
|
||||
)
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model,
|
||||
response,
|
||||
model_response,
|
||||
_is_function_call,
|
||||
stream,
|
||||
logging_obj,
|
||||
api_key,
|
||||
data,
|
||||
messages,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
):
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
|
@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
completion_response["stop_reason"]
|
||||
)
|
||||
|
||||
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
|
||||
if _is_function_call and stream:
|
||||
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||
# return an iterator
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
streaming_model_response.choices[0].finish_reason = model_response.choices[
|
||||
0
|
||||
].finish_reason
|
||||
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||
streaming_choice = litellm.utils.StreamingChoices()
|
||||
streaming_choice.index = model_response.choices[0].index
|
||||
_tool_calls = []
|
||||
print_verbose(
|
||||
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
|
||||
)
|
||||
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
|
||||
if isinstance(model_response.choices[0], litellm.Choices):
|
||||
if getattr(
|
||||
model_response.choices[0].message, "tool_calls", None
|
||||
) is not None and isinstance(
|
||||
model_response.choices[0].message.tool_calls, list
|
||||
):
|
||||
for tool_call in model_response.choices[0].message.tool_calls:
|
||||
_tool_call = {**tool_call.dict(), "index": 0}
|
||||
_tool_calls.append(_tool_call)
|
||||
delta_obj = litellm.utils.Delta(
|
||||
content=getattr(model_response.choices[0].message, "content", None),
|
||||
role=model_response.choices[0].message.role,
|
||||
tool_calls=_tool_calls,
|
||||
)
|
||||
streaming_choice.delta = delta_obj
|
||||
streaming_model_response.choices = [streaming_choice]
|
||||
completion_stream = ModelResponseIterator(
|
||||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = completion_response["usage"]["input_tokens"]
|
||||
completion_tokens = completion_response["usage"]["output_tokens"]
|
||||
|
@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
return model_response
|
||||
|
||||
async def acompletion_stream_function(
|
||||
|
@ -289,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data=None,
|
||||
data: dict,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data=None,
|
||||
optional_params=None,
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
return self.process_response(
|
||||
if stream and _is_function_call:
|
||||
return self.process_streaming_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
_is_function_call=_is_function_call,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion(
|
||||
|
@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
return self.process_response(
|
||||
|
||||
if stream and _is_function_call:
|
||||
return self.process_streaming_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
_is_function_call=_is_function_call,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def embedding(self):
|
||||
|
|
|
@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def process_response(
|
||||
def _process_response(
|
||||
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
|
||||
):
|
||||
## RESPONSE OBJECT
|
||||
|
@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
response = self.process_response(
|
||||
response = self._process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
|
@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
response = self.process_response(
|
||||
response = self._process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
|
|
|
@ -8,14 +8,16 @@ from litellm.utils import (
|
|||
CustomStreamWrapper,
|
||||
convert_to_model_response_object,
|
||||
TranscriptionResponse,
|
||||
get_secret,
|
||||
)
|
||||
from typing import Callable, Optional, BinaryIO
|
||||
from typing import Callable, Optional, BinaryIO, List
|
||||
from litellm import OpenAIConfig
|
||||
import litellm, json
|
||||
import httpx # type: ignore
|
||||
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
|
||||
from openai import AzureOpenAI, AsyncAzureOpenAI
|
||||
import uuid
|
||||
import os
|
||||
|
||||
|
||||
class AzureOpenAIError(Exception):
|
||||
|
@ -105,6 +107,12 @@ class AzureOpenAIConfig(OpenAIConfig):
|
|||
optional_params["azure_ad_token"] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||
"""
|
||||
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
# azure_client_params = {
|
||||
|
@ -126,6 +134,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
|||
return azure_client_params
|
||||
|
||||
|
||||
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
|
||||
|
||||
if azure_client_id is None or azure_tenant is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422,
|
||||
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
||||
)
|
||||
|
||||
oidc_token = get_secret(azure_ad_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=401,
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
)
|
||||
|
||||
req_token = httpx.post(
|
||||
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
|
||||
data={
|
||||
"client_id": azure_client_id,
|
||||
"grant_type": "client_credentials",
|
||||
"scope": "https://cognitiveservices.azure.com/.default",
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": oidc_token,
|
||||
},
|
||||
)
|
||||
|
||||
if req_token.status_code != 200:
|
||||
raise AzureOpenAIError(
|
||||
status_code=req_token.status_code,
|
||||
message=req_token.text,
|
||||
)
|
||||
|
||||
possible_azure_ad_token = req_token.json().get("access_token", None)
|
||||
|
||||
if possible_azure_ad_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token not returned"
|
||||
)
|
||||
|
||||
return possible_azure_ad_token
|
||||
|
||||
|
||||
class AzureChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
@ -137,6 +190,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
headers["api-key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
return headers
|
||||
|
||||
|
@ -189,6 +244,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if acompletion is True:
|
||||
|
@ -276,6 +334,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
|
@ -351,6 +411,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
# setting Azure client
|
||||
|
@ -422,6 +484,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
|
@ -478,6 +542,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
|
@ -599,6 +665,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
## LOGGING
|
||||
|
@ -755,6 +823,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if aimg_generation == True:
|
||||
|
@ -833,6 +903,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if max_retries is not None:
|
||||
|
|
|
@ -1,12 +1,32 @@
|
|||
## This is a template base class to be used for adding new LLM providers via API calls
|
||||
import litellm
|
||||
import httpx
|
||||
from typing import Optional
|
||||
import httpx, requests
|
||||
from typing import Optional, Union
|
||||
from litellm.utils import Logging
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
_client_session: Optional[httpx.Client] = None
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: litellm.utils.ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> litellm.utils.ModelResponse:
|
||||
"""
|
||||
Helper function to process the response across sync + async completion calls
|
||||
"""
|
||||
return model_response
|
||||
|
||||
def create_client_session(self):
|
||||
if litellm.client_session:
|
||||
_client_session = litellm.client_session
|
||||
|
|
|
@ -52,6 +52,16 @@ class AmazonBedrockGlobalConfig:
|
|||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.aws-services.info/bedrock.html
|
||||
"""
|
||||
return [
|
||||
"eu-west-1",
|
||||
"eu-west-3",
|
||||
"eu-central-1",
|
||||
]
|
||||
|
||||
|
||||
class AmazonTitanConfig:
|
||||
"""
|
||||
|
@ -551,6 +561,7 @@ def init_bedrock_client(
|
|||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
):
|
||||
|
@ -567,6 +578,7 @@ def init_bedrock_client(
|
|||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
|
@ -582,6 +594,7 @@ def init_bedrock_client(
|
|||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
) = params_to_check
|
||||
|
||||
### SET REGION NAME
|
||||
|
@ -620,7 +633,38 @@ def init_bedrock_client(
|
|||
config = boto3.session.Config()
|
||||
|
||||
### CHECK STS ###
|
||||
if aws_role_name is not None and aws_session_name is not None:
|
||||
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client(
|
||||
"sts"
|
||||
)
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
# use sts if role name passed in
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
|
@ -755,6 +799,7 @@ def completion(
|
|||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = optional_params.pop("aws_bedrock_client", None)
|
||||
|
@ -769,6 +814,7 @@ def completion(
|
|||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
@ -1291,6 +1337,7 @@ def embedding(
|
|||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
|
@ -1298,6 +1345,7 @@ def embedding(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
)
|
||||
|
@ -1380,6 +1428,7 @@ def image_generation(
|
|||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
|
@ -1387,6 +1436,7 @@ def image_generation(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
timeout=timeout,
|
||||
|
|
733
litellm/llms/bedrock_httpx.py
Normal file
733
litellm/llms/bedrock_httpx.py
Normal file
|
@ -0,0 +1,733 @@
|
|||
# What is this?
|
||||
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||
## V0 - just covers cohere command-r support
|
||||
|
||||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import (
|
||||
Callable,
|
||||
Optional,
|
||||
List,
|
||||
Literal,
|
||||
Union,
|
||||
Any,
|
||||
TypedDict,
|
||||
Tuple,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
)
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
Usage,
|
||||
map_finish_reason,
|
||||
CustomStreamWrapper,
|
||||
Message,
|
||||
Choices,
|
||||
get_secret,
|
||||
Logging,
|
||||
)
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
from .bedrock import BedrockError, convert_messages_to_prompt
|
||||
from litellm.types.llms.bedrock import *
|
||||
|
||||
|
||||
class AmazonCohereChatConfig:
|
||||
"""
|
||||
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||
"""
|
||||
|
||||
documents: Optional[List[Document]] = None
|
||||
search_queries_only: Optional[bool] = None
|
||||
preamble: Optional[str] = None
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
p: Optional[float] = None
|
||||
k: Optional[float] = None
|
||||
prompt_truncation: Optional[str] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
return_prompt: Optional[bool] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
raw_prompting: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
documents: Optional[List[Document]] = None,
|
||||
search_queries_only: Optional[bool] = None,
|
||||
preamble: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
p: Optional[float] = None,
|
||||
k: Optional[float] = None,
|
||||
prompt_truncation: Optional[str] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
return_prompt: Optional[bool] = None,
|
||||
stop_sequences: Optional[str] = None,
|
||||
raw_prompting: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self) -> List[str]:
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
value = [value]
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["p"] = value
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
if param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
if "seed":
|
||||
optional_params["seed"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class BedrockLLM(BaseLLM):
|
||||
"""
|
||||
Example call
|
||||
|
||||
```
|
||||
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Accept: application/json' \
|
||||
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
|
||||
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
|
||||
--data-raw '{
|
||||
"prompt": "Hi",
|
||||
"temperature": 0,
|
||||
"p": 0.9,
|
||||
"max_tokens": 4096
|
||||
}'
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def convert_messages_to_prompt(
|
||||
self, model, messages, provider, custom_prompt_dict
|
||||
) -> Tuple[str, Optional[list]]:
|
||||
# handle anthropic prompts and amazon titan prompts
|
||||
prompt = ""
|
||||
chat_history: Optional[list] = None
|
||||
if provider == "anthropic" or provider == "amazon":
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "mistral":
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "meta":
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "cohere":
|
||||
prompt, chat_history = cohere_message_pt(messages=messages)
|
||||
else:
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
return prompt, chat_history # type: ignore
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Return a boto3.Credentials object
|
||||
"""
|
||||
import boto3
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
params_to_check: List[Optional[str]] = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith("os.environ/"):
|
||||
_v = get_secret(param)
|
||||
if _v is not None and isinstance(_v, str):
|
||||
params_to_check[i] = _v
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
) = params_to_check
|
||||
|
||||
### CHECK STS ###
|
||||
if aws_role_name is not None and aws_session_name is not None:
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
return sts_response["Credentials"]
|
||||
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
client = boto3.Session(profile_name=aws_profile_name)
|
||||
|
||||
return client.get_credentials()
|
||||
else:
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return session.get_credentials()
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise BedrockError(message=response.text, status_code=422)
|
||||
|
||||
try:
|
||||
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||
except Exception as e:
|
||||
raise BedrockError(message=response.text, status_code=422)
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
prompt_tokens = int(
|
||||
response.headers.get(
|
||||
"x-amzn-bedrock-input-token-count",
|
||||
len(encoding.encode("".join(m.get("content", "") for m in messages))),
|
||||
)
|
||||
)
|
||||
completion_tokens = int(
|
||||
response.headers.get(
|
||||
"x-amzn-bedrock-output-token-count",
|
||||
len(
|
||||
encoding.encode(
|
||||
model_response.choices[0].message.content, # type: ignore
|
||||
disallowed_special=(),
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion: bool,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
try:
|
||||
import boto3
|
||||
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url = ""
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
if stream is not None and stream == True:
|
||||
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
provider = model.split(".")[0]
|
||||
prompt, chat_history = self.convert_messages_to_prompt(
|
||||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
|
||||
if provider == "cohere":
|
||||
if model.startswith("cohere.command-r"):
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonCohereChatConfig().get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
_data = {"message": prompt, **inference_params}
|
||||
if chat_history is not None:
|
||||
_data["chat_history"] = chat_history
|
||||
data = json.dumps(_data)
|
||||
else:
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonCohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
if stream == True:
|
||||
inference_params["stream"] = (
|
||||
True # cohere requires stream = True in inference params
|
||||
)
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
else:
|
||||
raise Exception("UNSUPPORTED PROVIDER")
|
||||
|
||||
## COMPLETION CALL
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepped.url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=False,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client
|
||||
if stream is not None and stream == True:
|
||||
response = self.client.post(
|
||||
url=prepped.url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
data=data,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
decoder = AWSEventStreamDecoder()
|
||||
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=response.text)
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ModelResponse:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client # type: ignore
|
||||
|
||||
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client # type: ignore
|
||||
|
||||
response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(status_code=response.status_code, message=response.text)
|
||||
|
||||
decoder = AWSEventStreamDecoder()
|
||||
|
||||
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
def embedding(self, *args, **kwargs):
|
||||
return super().embedding(*args, **kwargs)
|
||||
|
||||
|
||||
def get_response_stream_shape():
|
||||
from botocore.model import ServiceModel
|
||||
from botocore.loaders import Loader
|
||||
|
||||
loader = Loader()
|
||||
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
|
||||
bedrock_service_model = ServiceModel(bedrock_service_dict)
|
||||
return bedrock_service_model.shape_for("ResponseStream")
|
||||
|
||||
|
||||
class AWSEventStreamDecoder:
|
||||
def __init__(self) -> None:
|
||||
from botocore.parsers import EventStreamJSONParser
|
||||
|
||||
self.parser = EventStreamJSONParser()
|
||||
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
|
||||
from botocore.eventstream import EventStreamBuffer
|
||||
|
||||
event_stream_buffer = EventStreamBuffer()
|
||||
for chunk in iterator:
|
||||
event_stream_buffer.add_data(chunk)
|
||||
for event in event_stream_buffer:
|
||||
message = self._parse_message_from_event(event)
|
||||
if message:
|
||||
# sse_event = ServerSentEvent(data=message, event="completion")
|
||||
_data = json.loads(message)
|
||||
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||
text=_data.get("text", ""),
|
||||
is_finished=_data.get("is_finished", False),
|
||||
finish_reason=_data.get("finish_reason", ""),
|
||||
)
|
||||
yield streaming_chunk
|
||||
|
||||
async def aiter_bytes(
|
||||
self, iterator: AsyncIterator[bytes]
|
||||
) -> AsyncIterator[GenericStreamingChunk]:
|
||||
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
|
||||
from botocore.eventstream import EventStreamBuffer
|
||||
|
||||
event_stream_buffer = EventStreamBuffer()
|
||||
async for chunk in iterator:
|
||||
event_stream_buffer.add_data(chunk)
|
||||
for event in event_stream_buffer:
|
||||
message = self._parse_message_from_event(event)
|
||||
if message:
|
||||
_data = json.loads(message)
|
||||
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||
text=_data.get("text", ""),
|
||||
is_finished=_data.get("is_finished", False),
|
||||
finish_reason=_data.get("finish_reason", ""),
|
||||
)
|
||||
yield streaming_chunk
|
||||
|
||||
def _parse_message_from_event(self, event) -> Optional[str]:
|
||||
response_dict = event.to_response_dict()
|
||||
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
|
||||
if response_dict["status_code"] != 200:
|
||||
raise ValueError(f"Bad response code, expected 200: {response_dict}")
|
||||
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
328
litellm/llms/clarifai.py
Normal file
328
litellm/llms/clarifai.py
Normal file
|
@ -0,0 +1,328 @@
|
|||
import os, types, traceback
|
||||
import json
|
||||
import requests
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage, Choices, Message, CustomStreamWrapper
|
||||
import litellm
|
||||
import httpx
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class ClarifaiError(Exception):
|
||||
def __init__(self, status_code, message, url):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url=url
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
)
|
||||
|
||||
class ClarifaiConfig:
|
||||
"""
|
||||
Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
|
||||
TODO fill in the details
|
||||
"""
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def completions_to_model(payload):
|
||||
# if payload["n"] != 1:
|
||||
# raise HTTPException(
|
||||
# status_code=422,
|
||||
# detail="Only one generation is supported. Please set candidate_count to 1.",
|
||||
# )
|
||||
|
||||
params = {}
|
||||
if temperature := payload.get("temperature"):
|
||||
params["temperature"] = temperature
|
||||
if max_tokens := payload.get("max_tokens"):
|
||||
params["max_tokens"] = max_tokens
|
||||
return {
|
||||
"inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
|
||||
"model": {"output_info": {"params": params}},
|
||||
}
|
||||
|
||||
def process_response(
|
||||
model,
|
||||
prompt,
|
||||
response,
|
||||
model_response,
|
||||
api_key,
|
||||
data,
|
||||
encoding,
|
||||
logging_obj
|
||||
):
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except Exception:
|
||||
raise ClarifaiError(
|
||||
message=response.text, status_code=response.status_code, url=model
|
||||
)
|
||||
# print(completion_response)
|
||||
try:
|
||||
choices_list = []
|
||||
for idx, item in enumerate(completion_response["outputs"]):
|
||||
if len(item["data"]["text"]["raw"]) > 0:
|
||||
message_obj = Message(content=item["data"]["text"]["raw"])
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(
|
||||
finish_reason="stop",
|
||||
index=idx + 1, #check
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
|
||||
except Exception as e:
|
||||
raise ClarifaiError(
|
||||
message=traceback.format_exc(), status_code=response.status_code, url=model
|
||||
)
|
||||
|
||||
# Calculate Usage
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content"))
|
||||
)
|
||||
model_response["model"] = model
|
||||
model_response["usage"] = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
return model_response
|
||||
|
||||
def convert_model_to_url(model: str, api_base: str):
|
||||
user_id, app_id, model_id = model.split(".")
|
||||
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
|
||||
|
||||
def get_prompt_model_name(url: str):
|
||||
clarifai_model_name = url.split("/")[-2]
|
||||
if "claude" in clarifai_model_name:
|
||||
return "anthropic", clarifai_model_name.replace("_", ".")
|
||||
if ("llama" in clarifai_model_name)or ("mistral" in clarifai_model_name):
|
||||
return "", "meta-llama/llama-2-chat"
|
||||
else:
|
||||
return "", clarifai_model_name
|
||||
|
||||
async def async_completion(
|
||||
model: str,
|
||||
prompt: str,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
data=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={}):
|
||||
|
||||
async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
response = await async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
|
||||
return process_response(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
custom_prompt_dict={},
|
||||
acompletion=False,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
headers = validate_environment(api_key)
|
||||
model = convert_model_to_url(model, api_base)
|
||||
prompt = " ".join(message["content"] for message in messages) # TODO
|
||||
|
||||
## Load Config
|
||||
config = litellm.ClarifaiConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
):
|
||||
optional_params[k] = v
|
||||
|
||||
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
|
||||
if custom_llm_provider == "anthropic":
|
||||
prompt = prompt_factory(
|
||||
model=orig_model_name,
|
||||
messages=messages,
|
||||
api_key=api_key,
|
||||
custom_llm_provider="clarifai"
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=orig_model_name,
|
||||
messages=messages,
|
||||
api_key=api_key,
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
# print(prompt); exit(0)
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
data = completions_to_model(data)
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
if acompletion==True:
|
||||
return async_completion(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
data=data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
model,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
)
|
||||
# print(response.content); exit()
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ClarifaiError(status_code=response.status_code, message=response.text, url=model)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
completion_stream = response.iter_lines()
|
||||
stream_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="clarifai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return stream_response
|
||||
|
||||
else:
|
||||
return process_response(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj)
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, model_response):
|
||||
self.model_response = model_response
|
||||
self.is_done = False
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.is_done:
|
||||
raise StopIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.is_done:
|
||||
raise StopAsyncIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
|
@ -58,8 +58,15 @@ class AsyncHTTPHandler:
|
|||
|
||||
class HTTPHandler:
|
||||
def __init__(
|
||||
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
|
||||
self,
|
||||
timeout: Optional[httpx.Timeout] = None,
|
||||
concurrent_limit=1000,
|
||||
client: Optional[httpx.Client] = None,
|
||||
):
|
||||
if timeout is None:
|
||||
timeout = _DEFAULT_TIMEOUT
|
||||
|
||||
if client is None:
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
|
@ -68,6 +75,8 @@ class HTTPHandler:
|
|||
max_keepalive_connections=concurrent_limit,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.client = client
|
||||
|
||||
def close(self):
|
||||
# Close the client when you're done with it
|
||||
|
@ -82,11 +91,15 @@ class HTTPHandler:
|
|||
def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Optional[dict] = None,
|
||||
data: Optional[Union[dict, str]] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
response = self.client.post(url, data=data, params=params, headers=headers)
|
||||
req = self.client.build_request(
|
||||
"POST", url, data=data, params=params, headers=headers # type: ignore
|
||||
)
|
||||
response = self.client.send(req, stream=stream)
|
||||
return response
|
||||
|
||||
def __del__(self) -> None:
|
||||
|
|
|
@ -6,10 +6,12 @@ import httpx, requests
|
|||
from .base import BaseLLM
|
||||
import time
|
||||
import litellm
|
||||
from typing import Callable, Dict, List, Any
|
||||
from typing import Callable, Dict, List, Any, Literal, Tuple
|
||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
||||
from typing import Optional
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.types.completion import ChatCompletionMessageToolCallParam
|
||||
import enum
|
||||
|
||||
|
||||
class HuggingfaceError(Exception):
|
||||
|
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
hf_task_list = [
|
||||
"text-generation-inference",
|
||||
"conversational",
|
||||
"text-classification",
|
||||
"text-generation",
|
||||
]
|
||||
|
||||
hf_tasks = Literal[
|
||||
"text-generation-inference",
|
||||
"conversational",
|
||||
"text-classification",
|
||||
"text-generation",
|
||||
]
|
||||
|
||||
|
||||
class HuggingfaceConfig:
|
||||
"""
|
||||
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||
"""
|
||||
|
||||
hf_task: Optional[hf_tasks] = (
|
||||
None # litellm-specific param, used to know the api spec to use when calling huggingface api
|
||||
)
|
||||
best_of: Optional[int] = None
|
||||
decoder_input_details: Optional[bool] = None
|
||||
details: Optional[bool] = True # enables returning logprobs + best of
|
||||
|
@ -101,6 +121,51 @@ class HuggingfaceConfig:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"n",
|
||||
"echo",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
|
||||
if param == "temperature":
|
||||
if value == 0.0 or value == 0:
|
||||
# hugging face exception raised when temp==0
|
||||
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
|
||||
value = 0.01
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "n":
|
||||
optional_params["best_of"] = value
|
||||
optional_params["do_sample"] = (
|
||||
True # Need to sample if you want best of for hf inference endpoints
|
||||
)
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop"] = value
|
||||
if param == "max_tokens":
|
||||
# HF TGI raises the following exception when max_new_tokens==0
|
||||
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
|
||||
if value == 0:
|
||||
value = 1
|
||||
optional_params["max_new_tokens"] = value
|
||||
if param == "echo":
|
||||
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
|
||||
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
|
||||
optional_params["decoder_input_details"] = True
|
||||
return optional_params
|
||||
|
||||
|
||||
def output_parser(generated_text: str):
|
||||
"""
|
||||
|
@ -162,18 +227,21 @@ def read_tgi_conv_models():
|
|||
return set(), set()
|
||||
|
||||
|
||||
def get_hf_task_for_model(model):
|
||||
def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
|
||||
# read text file, cast it to set
|
||||
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||
if model.split("/")[0] in hf_task_list:
|
||||
split_model = model.split("/", 1)
|
||||
return split_model[0], split_model[1] # type: ignore
|
||||
tgi_models, conversational_models = read_tgi_conv_models()
|
||||
if model in tgi_models:
|
||||
return "text-generation-inference"
|
||||
return "text-generation-inference", model
|
||||
elif model in conversational_models:
|
||||
return "conversational"
|
||||
return "conversational", model
|
||||
elif "roneneldan/TinyStories" in model:
|
||||
return None
|
||||
return "text-generation", model
|
||||
else:
|
||||
return "text-generation-inference" # default to tgi
|
||||
return "text-generation-inference", model # default to tgi
|
||||
|
||||
|
||||
class Huggingface(BaseLLM):
|
||||
|
@ -202,7 +270,7 @@ class Huggingface(BaseLLM):
|
|||
self,
|
||||
completion_response,
|
||||
model_response,
|
||||
task,
|
||||
task: hf_tasks,
|
||||
optional_params,
|
||||
encoding,
|
||||
input_text,
|
||||
|
@ -270,6 +338,10 @@ class Huggingface(BaseLLM):
|
|||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"].extend(choices_list)
|
||||
elif task == "text-classification":
|
||||
model_response["choices"][0]["message"]["content"] = json.dumps(
|
||||
completion_response
|
||||
)
|
||||
else:
|
||||
if len(completion_response[0]["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = output_parser(
|
||||
|
@ -332,7 +404,13 @@ class Huggingface(BaseLLM):
|
|||
exception_mapping_worked = False
|
||||
try:
|
||||
headers = self.validate_environment(api_key, headers)
|
||||
task = get_hf_task_for_model(model)
|
||||
task, model = get_hf_task_for_model(model)
|
||||
## VALIDATE API FORMAT
|
||||
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
||||
raise Exception(
|
||||
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
|
||||
)
|
||||
|
||||
print_verbose(f"{model}, {task}")
|
||||
completion_url = ""
|
||||
input_text = ""
|
||||
|
@ -433,14 +511,15 @@ class Huggingface(BaseLLM):
|
|||
inference_params.pop("return_full_text")
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": inference_params,
|
||||
"stream": ( # type: ignore
|
||||
True
|
||||
}
|
||||
if task == "text-generation-inference":
|
||||
data["parameters"] = inference_params
|
||||
data["stream"] = ( # type: ignore
|
||||
True # type: ignore
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
else False
|
||||
),
|
||||
}
|
||||
)
|
||||
input_text = prompt
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
@ -531,10 +610,10 @@ class Huggingface(BaseLLM):
|
|||
isinstance(completion_response, dict)
|
||||
and "error" in completion_response
|
||||
):
|
||||
print_verbose(f"completion error: {completion_response['error']}")
|
||||
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
|
||||
print_verbose(f"response.status_code: {response.status_code}")
|
||||
raise HuggingfaceError(
|
||||
message=completion_response["error"],
|
||||
message=completion_response["error"], # type: ignore
|
||||
status_code=response.status_code,
|
||||
)
|
||||
return self.convert_to_model_response_object(
|
||||
|
@ -563,7 +642,7 @@ class Huggingface(BaseLLM):
|
|||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
task: str,
|
||||
task: hf_tasks,
|
||||
encoding: Any,
|
||||
input_text: str,
|
||||
model: str,
|
||||
|
|
|
@ -300,7 +300,7 @@ def get_ollama_response(
|
|||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"] = response_json["message"]
|
||||
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + model
|
||||
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
|
||||
|
@ -484,7 +484,7 @@ async def ollama_acompletion(
|
|||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"] = response_json["message"]
|
||||
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama_chat/" + data["model"]
|
||||
|
|
|
@ -53,6 +53,113 @@ class OpenAIError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class MistralConfig:
|
||||
"""
|
||||
Reference: https://docs.mistral.ai/api/
|
||||
|
||||
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
|
||||
|
||||
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
|
||||
|
||||
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
|
||||
|
||||
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
|
||||
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
|
||||
|
||||
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
|
||||
"""
|
||||
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
tools: Optional[list] = None
|
||||
tool_choice: Optional[Literal["auto", "any", "none"]] = None
|
||||
random_seed: Optional[int] = None
|
||||
safe_prompt: Optional[bool] = None
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
|
||||
random_seed: Optional[int] = None,
|
||||
safe_prompt: Optional[bool] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def _map_tool_choice(self, tool_choice: str) -> str:
|
||||
if tool_choice == "auto" or tool_choice == "none":
|
||||
return tool_choice
|
||||
elif tool_choice == "required":
|
||||
return "any"
|
||||
else: # openai 'tool_choice' object param not supported by Mistral API
|
||||
return "any"
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "stream" and value == True:
|
||||
optional_params["stream"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "tool_choice" and isinstance(value, str):
|
||||
optional_params["tool_choice"] = self._map_tool_choice(
|
||||
tool_choice=value
|
||||
)
|
||||
if param == "seed":
|
||||
optional_params["extra_body"] = {"random_seed": value}
|
||||
return optional_params
|
||||
|
||||
|
||||
class OpenAIConfig:
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
@ -1327,8 +1434,8 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=client,
|
||||
)
|
||||
|
||||
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
|
||||
thread_id, **message_data
|
||||
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
|
||||
thread_id, **message_data # type: ignore
|
||||
)
|
||||
|
||||
response_obj: Optional[OpenAIMessage] = None
|
||||
|
@ -1458,7 +1565,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.runs.create_and_poll(
|
||||
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
|
|
|
@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
data: Union[dict, str],
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
|
@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise PredibaseError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
raise PredibaseError(message=response.text, status_code=422)
|
||||
if "error" in completion_response:
|
||||
raise PredibaseError(
|
||||
message=str(completion_response["error"]),
|
||||
|
@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
if acompletion is True:
|
||||
if acompletion == True:
|
||||
### ASYNC STREAMING
|
||||
if stream == True:
|
||||
return self.async_streaming(
|
||||
|
|
|
@ -1509,6 +1509,11 @@ def prompt_factory(
|
|||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
elif custom_llm_provider == "clarifai":
|
||||
if "claude" in model:
|
||||
return anthropic_pt(messages=messages)
|
||||
|
||||
elif custom_llm_provider == "perplexity":
|
||||
for message in messages:
|
||||
message.pop("name", None)
|
||||
|
|
119
litellm/llms/triton.py
Normal file
119
litellm/llms/triton.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class TritonError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST",
|
||||
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class TritonChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def aembedding(
|
||||
self,
|
||||
data: dict,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
api_base: str,
|
||||
logging_obj=None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
|
||||
async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
|
||||
response = await async_handler.post(url=api_base, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise TritonError(status_code=response.status_code, message=response.text)
|
||||
|
||||
_text_response = response.text
|
||||
|
||||
logging_obj.post_call(original_response=_text_response)
|
||||
|
||||
_json_response = response.json()
|
||||
|
||||
_outputs = _json_response["outputs"]
|
||||
_output_data = _outputs[0]["data"]
|
||||
_embedding_output = {
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": _output_data,
|
||||
}
|
||||
|
||||
model_response.model = _json_response.get("model_name", "None")
|
||||
model_response.data = [_embedding_output]
|
||||
|
||||
return model_response
|
||||
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: list,
|
||||
timeout: float,
|
||||
api_base: str,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
api_key: Optional[str] = None,
|
||||
logging_obj=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
):
|
||||
data_for_triton = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "input_text",
|
||||
"shape": [1],
|
||||
"datatype": "BYTES",
|
||||
"data": input,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
|
||||
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
|
||||
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": curl_string,
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(
|
||||
data=data_for_triton,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
raise Exception(
|
||||
"Only async embedding supported for triton, please use litellm.aembedding() for now"
|
||||
)
|
|
@ -198,6 +198,23 @@ class VertexAIConfig:
|
|||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
|
||||
"""
|
||||
return [
|
||||
"europe-central2",
|
||||
"europe-north1",
|
||||
"europe-southwest1",
|
||||
"europe-west1",
|
||||
"europe-west2",
|
||||
"europe-west3",
|
||||
"europe-west4",
|
||||
"europe-west6",
|
||||
"europe-west8",
|
||||
"europe-west9",
|
||||
]
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
|
@ -419,6 +436,7 @@ def completion(
|
|||
from google.protobuf.struct_pb2 import Value # type: ignore
|
||||
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
|
||||
import google.auth # type: ignore
|
||||
import proto # type: ignore
|
||||
|
||||
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
|
||||
print_verbose(
|
||||
|
@ -605,9 +623,21 @@ def completion(
|
|||
):
|
||||
function_call = response.candidates[0].content.parts[0].function_call
|
||||
args_dict = {}
|
||||
for k, v in function_call.args.items():
|
||||
args_dict[k] = v
|
||||
|
||||
# Check if it's a RepeatedComposite instance
|
||||
for key, val in function_call.args.items():
|
||||
if isinstance(
|
||||
val, proto.marshal.collections.repeated.RepeatedComposite
|
||||
):
|
||||
# If so, convert to list
|
||||
args_dict[key] = [v for v in val]
|
||||
else:
|
||||
args_dict[key] = val
|
||||
|
||||
try:
|
||||
args_str = json.dumps(args_dict)
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=422, message=str(e))
|
||||
message = litellm.Message(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
|
@ -810,6 +840,8 @@ def completion(
|
|||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
if isinstance(e, VertexAIError):
|
||||
raise e
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
||||
|
||||
|
@ -835,6 +867,8 @@ async def async_completion(
|
|||
Add support for acompletion calls for gemini-pro
|
||||
"""
|
||||
try:
|
||||
import proto # type: ignore
|
||||
|
||||
if mode == "vision":
|
||||
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
|
@ -869,9 +903,21 @@ async def async_completion(
|
|||
):
|
||||
function_call = response.candidates[0].content.parts[0].function_call
|
||||
args_dict = {}
|
||||
for k, v in function_call.args.items():
|
||||
args_dict[k] = v
|
||||
|
||||
# Check if it's a RepeatedComposite instance
|
||||
for key, val in function_call.args.items():
|
||||
if isinstance(
|
||||
val, proto.marshal.collections.repeated.RepeatedComposite
|
||||
):
|
||||
# If so, convert to list
|
||||
args_dict[key] = [v for v in val]
|
||||
else:
|
||||
args_dict[key] = val
|
||||
|
||||
try:
|
||||
args_str = json.dumps(args_dict)
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=422, message=str(e))
|
||||
message = litellm.Message(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
|
|
|
@ -1,12 +1,26 @@
|
|||
from enum import Enum
|
||||
import json, types, time # noqa: E401
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, Optional, Any, Union, List
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
AsyncGenerator,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
Optional,
|
||||
Any,
|
||||
Union,
|
||||
List,
|
||||
ContextManager,
|
||||
AsyncContextManager,
|
||||
)
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
from litellm.utils import ModelResponse, Usage, get_secret
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
from .base import BaseLLM
|
||||
from .prompt_templates import factory as ptf
|
||||
|
@ -149,6 +163,15 @@ class IBMWatsonXAIConfig:
|
|||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
|
||||
"""
|
||||
return [
|
||||
"eu-de",
|
||||
"eu-gb",
|
||||
]
|
||||
|
||||
|
||||
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
||||
# handle anthropic prompts and amazon titan prompts
|
||||
|
@ -188,11 +211,12 @@ class WatsonXAIEndpoint(str, Enum):
|
|||
)
|
||||
EMBEDDINGS = "/ml/v1/text/embeddings"
|
||||
PROMPTS = "/ml/v1/prompts"
|
||||
AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
|
||||
|
||||
|
||||
class IBMWatsonXAI(BaseLLM):
|
||||
"""
|
||||
Class to interface with IBM Watsonx.ai API for text generation and embeddings.
|
||||
Class to interface with IBM watsonx.ai API for text generation and embeddings.
|
||||
|
||||
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
|
||||
"""
|
||||
|
@ -343,7 +367,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
)
|
||||
if token is None and api_key is not None:
|
||||
# generate the auth token
|
||||
if print_verbose:
|
||||
if print_verbose is not None:
|
||||
print_verbose("Generating IAM token for Watsonx.ai")
|
||||
token = self.generate_iam_token(api_key)
|
||||
elif token is None and api_key is None:
|
||||
|
@ -378,10 +402,11 @@ class IBMWatsonXAI(BaseLLM):
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
litellm_params: Optional[dict] = None,
|
||||
optional_params=None,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
timeout: Optional[float] = None,
|
||||
timeout=None,
|
||||
):
|
||||
"""
|
||||
Send a text generation request to the IBM Watsonx.ai API.
|
||||
|
@ -402,12 +427,12 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
|
||||
def process_text_request(request_params: dict) -> ModelResponse:
|
||||
with self._manage_response(
|
||||
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
def process_text_gen_response(json_resp: dict) -> ModelResponse:
|
||||
if "results" not in json_resp:
|
||||
raise WatsonXAIError(
|
||||
status_code=500,
|
||||
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
|
||||
)
|
||||
generated_text = json_resp["results"][0]["generated_text"]
|
||||
prompt_tokens = json_resp["results"][0]["input_token_count"]
|
||||
completion_tokens = json_resp["results"][0]["generated_token_count"]
|
||||
|
@ -415,36 +440,70 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
setattr(
|
||||
model_response,
|
||||
"usage",
|
||||
Usage(
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def process_stream_request(
|
||||
request_params: dict,
|
||||
def process_stream_response(
|
||||
stream_resp: Union[Iterator[str], AsyncIterator],
|
||||
) -> litellm.CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
with self._manage_response(
|
||||
request_params,
|
||||
logging_obj=logging_obj,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
response = litellm.CustomStreamWrapper(
|
||||
resp.iter_lines(),
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
stream_resp,
|
||||
model=model,
|
||||
custom_llm_provider="watsonx",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return response
|
||||
return streamwrapper
|
||||
|
||||
# create the function to manage the request to watsonx.ai
|
||||
self.request_manager = RequestManager(logging_obj)
|
||||
|
||||
def handle_text_request(request_params: dict) -> ModelResponse:
|
||||
with self.request_manager.request(
|
||||
request_params,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
return process_text_gen_response(json_resp)
|
||||
|
||||
async def handle_text_request_async(request_params: dict) -> ModelResponse:
|
||||
async with self.request_manager.async_request(
|
||||
request_params,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_text_gen_response(json_resp)
|
||||
|
||||
def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
with self.request_manager.request(
|
||||
request_params,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
streamwrapper = process_stream_response(resp.iter_lines())
|
||||
return streamwrapper
|
||||
|
||||
async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
async with self.request_manager.async_request(
|
||||
request_params,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
streamwrapper = process_stream_response(resp.aiter_lines())
|
||||
return streamwrapper
|
||||
|
||||
try:
|
||||
## Get the response from the model
|
||||
|
@ -455,10 +514,18 @@ class IBMWatsonXAI(BaseLLM):
|
|||
optional_params=optional_params,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if stream:
|
||||
return process_stream_request(req_params)
|
||||
if stream and (acompletion is True):
|
||||
# stream and async text generation
|
||||
return handle_stream_request_async(req_params)
|
||||
elif stream:
|
||||
# streaming text generation
|
||||
return handle_stream_request(req_params)
|
||||
elif (acompletion is True):
|
||||
# async text generation
|
||||
return handle_text_request_async(req_params)
|
||||
else:
|
||||
return process_text_request(req_params)
|
||||
# regular text generation
|
||||
return handle_text_request(req_params)
|
||||
except WatsonXAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
|
@ -473,6 +540,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model_response=None,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
aembedding=None,
|
||||
):
|
||||
"""
|
||||
Send a text embedding request to the IBM Watsonx.ai API.
|
||||
|
@ -507,9 +575,6 @@ class IBMWatsonXAI(BaseLLM):
|
|||
}
|
||||
request_params = dict(version=api_params["api_version"])
|
||||
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
|
||||
# request = httpx.Request(
|
||||
# "POST", url, headers=headers, json=payload, params=request_params
|
||||
# )
|
||||
req_params = {
|
||||
"method": "POST",
|
||||
"url": url,
|
||||
|
@ -517,26 +582,50 @@ class IBMWatsonXAI(BaseLLM):
|
|||
"json": payload,
|
||||
"params": request_params,
|
||||
}
|
||||
with self._manage_response(
|
||||
req_params, logging_obj=logging_obj, input=input
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
request_manager = RequestManager(logging_obj)
|
||||
|
||||
def process_embedding_response(json_resp: dict) -> ModelResponse:
|
||||
results = json_resp.get("results", [])
|
||||
embedding_response = []
|
||||
for idx, result in enumerate(results):
|
||||
embedding_response.append(
|
||||
{"object": "embedding", "index": idx, "embedding": result["embedding"]}
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": result["embedding"],
|
||||
}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
model_response["model"] = model
|
||||
input_tokens = json_resp.get("input_token_count", 0)
|
||||
model_response.usage = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=input_tokens,
|
||||
)
|
||||
return model_response
|
||||
|
||||
def handle_embedding(request_params: dict) -> ModelResponse:
|
||||
with request_manager.request(request_params, input=input) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_embedding_response(json_resp)
|
||||
|
||||
async def handle_aembedding(request_params: dict) -> ModelResponse:
|
||||
async with request_manager.async_request(request_params, input=input) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_embedding_response(json_resp)
|
||||
|
||||
try:
|
||||
if aembedding is True:
|
||||
return handle_embedding(req_params)
|
||||
else:
|
||||
return handle_aembedding(req_params)
|
||||
except WatsonXAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
|
||||
def generate_iam_token(self, api_key=None, **params):
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
@ -558,52 +647,144 @@ class IBMWatsonXAI(BaseLLM):
|
|||
self.token = iam_access_token
|
||||
return iam_access_token
|
||||
|
||||
@contextmanager
|
||||
def _manage_response(
|
||||
def get_available_models(self, *, ids_only: bool = True, **params):
|
||||
api_params = self._get_api_params(params)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_params['token']}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
request_params = dict(version=api_params["api_version"])
|
||||
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
|
||||
req_params = dict(method="GET", url=url, headers=headers, params=request_params)
|
||||
with RequestManager(logging_obj=None).request(req_params) as resp:
|
||||
json_resp = resp.json()
|
||||
if not ids_only:
|
||||
return json_resp
|
||||
return [res["model_id"] for res in json_resp["resources"]]
|
||||
|
||||
class RequestManager:
|
||||
"""
|
||||
Returns a context manager that manages the response from the request.
|
||||
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
|
||||
request_manager = RequestManager(logging_obj=logging_obj)
|
||||
async with request_manager.request(request_params) as resp:
|
||||
...
|
||||
# or
|
||||
with request_manager.async_request(request_params) as resp:
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, logging_obj=None):
|
||||
self.logging_obj = logging_obj
|
||||
|
||||
def pre_call(
|
||||
self,
|
||||
request_params: dict,
|
||||
logging_obj: Any,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
if self.logging_obj is None:
|
||||
return
|
||||
request_str = (
|
||||
f"response = {request_params['method']}(\n"
|
||||
f"\turl={request_params['url']},\n"
|
||||
f"\tjson={request_params['json']},\n"
|
||||
f"\tjson={request_params.get('json')},\n"
|
||||
f")"
|
||||
)
|
||||
logging_obj.pre_call(
|
||||
self.logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=request_params["headers"].get("Authorization"),
|
||||
additional_args={
|
||||
"complete_input_dict": request_params["json"],
|
||||
"complete_input_dict": request_params.get("json"),
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
try:
|
||||
if stream:
|
||||
resp = requests.request(
|
||||
**request_params,
|
||||
stream=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
yield resp
|
||||
else:
|
||||
resp = requests.request(**request_params)
|
||||
resp.raise_for_status()
|
||||
yield resp
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
logging_obj.post_call(
|
||||
|
||||
def post_call(self, resp, request_params):
|
||||
if self.logging_obj is None:
|
||||
return
|
||||
self.logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=request_params["headers"].get("Authorization"),
|
||||
original_response=json.dumps(resp.json()),
|
||||
additional_args={
|
||||
"status_code": resp.status_code,
|
||||
"complete_input_dict": request_params["json"],
|
||||
"complete_input_dict": request_params.get(
|
||||
"data", request_params.get("json")
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def request(
|
||||
self,
|
||||
request_params: dict,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout=None,
|
||||
) -> Generator[requests.Response, None, None]:
|
||||
"""
|
||||
Returns a context manager that yields the response from the request.
|
||||
"""
|
||||
self.pre_call(request_params, input)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
if stream:
|
||||
request_params["stream"] = stream
|
||||
try:
|
||||
resp = requests.request(**request_params)
|
||||
if not resp.ok:
|
||||
raise WatsonXAIError(
|
||||
status_code=resp.status_code,
|
||||
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
|
||||
)
|
||||
yield resp
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
self.post_call(resp, request_params)
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_request(
|
||||
self,
|
||||
request_params: dict,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout=None,
|
||||
) -> AsyncGenerator[httpx.Response, None]:
|
||||
self.pre_call(request_params, input)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
if stream:
|
||||
request_params["stream"] = stream
|
||||
try:
|
||||
# async with AsyncHTTPHandler(timeout=timeout) as client:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(
|
||||
timeout=request_params.pop("timeout", 600.0), connect=5.0
|
||||
),
|
||||
)
|
||||
# async_handler.client.verify = False
|
||||
if "json" in request_params:
|
||||
request_params["data"] = json.dumps(request_params.pop("json", {}))
|
||||
method = request_params.pop("method")
|
||||
if method.upper() == "POST":
|
||||
resp = await self.async_handler.post(**request_params)
|
||||
else:
|
||||
resp = await self.async_handler.get(**request_params)
|
||||
if resp.status_code not in [200, 201]:
|
||||
raise WatsonXAIError(
|
||||
status_code=resp.status_code,
|
||||
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
|
||||
)
|
||||
yield resp
|
||||
# await async_handler.close()
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
self.post_call(resp, request_params)
|
140
litellm/main.py
140
litellm/main.py
|
@ -9,12 +9,13 @@
|
|||
|
||||
import os, openai, sys, json, inspect, uuid, datetime, threading
|
||||
from typing import Any, Literal, Union, BinaryIO
|
||||
from typing_extensions import overload
|
||||
from functools import partial
|
||||
import dotenv, traceback, random, asyncio, time, contextvars
|
||||
from copy import deepcopy
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
import litellm
|
||||
from ._logging import verbose_logger
|
||||
from litellm import ( # type: ignore
|
||||
client,
|
||||
|
@ -47,6 +48,7 @@ from .llms import (
|
|||
ai21,
|
||||
sagemaker,
|
||||
bedrock,
|
||||
triton,
|
||||
huggingface_restapi,
|
||||
replicate,
|
||||
aleph_alpha,
|
||||
|
@ -56,6 +58,7 @@ from .llms import (
|
|||
ollama,
|
||||
ollama_chat,
|
||||
cloudflare,
|
||||
clarifai,
|
||||
cohere,
|
||||
cohere_chat,
|
||||
petals,
|
||||
|
@ -75,6 +78,8 @@ from .llms.anthropic import AnthropicChatCompletion
|
|||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.bedrock_httpx import BedrockLLM
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
|
@ -103,7 +108,6 @@ from litellm.utils import (
|
|||
)
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
openai_chat_completions = OpenAIChatCompletion()
|
||||
openai_text_completions = OpenAITextCompletion()
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
|
@ -112,6 +116,8 @@ azure_chat_completions = AzureChatCompletion()
|
|||
azure_text_completions = AzureTextCompletion()
|
||||
huggingface = Huggingface()
|
||||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -254,7 +260,7 @@ async def acompletion(
|
|||
- If `stream` is True, the function returns an async generator that yields completion lines.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
custom_llm_provider = None
|
||||
custom_llm_provider = kwargs.get("custom_llm_provider", None)
|
||||
# Adjusted to use explicit arguments instead of *args and **kwargs
|
||||
completion_kwargs = {
|
||||
"model": model,
|
||||
|
@ -286,6 +292,7 @@ async def acompletion(
|
|||
"model_list": model_list,
|
||||
"acompletion": True, # assuming this is a required parameter
|
||||
}
|
||||
if custom_llm_provider is None:
|
||||
_, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=model, api_base=completion_kwargs.get("base_url", None)
|
||||
)
|
||||
|
@ -297,9 +304,6 @@ async def acompletion(
|
|||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
|
||||
_, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=model, api_base=kwargs.get("api_base", None)
|
||||
)
|
||||
if (
|
||||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "azure"
|
||||
|
@ -321,6 +325,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
or custom_llm_provider == "predibase"
|
||||
or (custom_llm_provider == "bedrock" and "cohere" in model)
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -661,6 +666,7 @@ def completion(
|
|||
"supports_system_message",
|
||||
"region_name",
|
||||
"allowed_model_region",
|
||||
"model_config",
|
||||
]
|
||||
|
||||
default_params = openai_params + litellm_params
|
||||
|
@ -668,20 +674,6 @@ def completion(
|
|||
k: v for k, v in kwargs.items() if k not in default_params
|
||||
} # model-specific params - pass them straight to the model/provider
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
|
||||
try:
|
||||
if base_url is not None:
|
||||
api_base = base_url
|
||||
|
@ -721,9 +713,18 @@ def completion(
|
|||
"aws_region_name", None
|
||||
) # support region-based pricing for bedrock
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
|
||||
custom_llm_provider
|
||||
):
|
||||
timeout = timeout.read or 600 # default 10 min timeout
|
||||
elif not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
|
||||
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
|
||||
if input_cost_per_token is not None and output_cost_per_token is not None:
|
||||
print_verbose(f"Registering model={model} in model cost map")
|
||||
litellm.register_model(
|
||||
{
|
||||
f"{custom_llm_provider}/{model}": {
|
||||
|
@ -845,6 +846,10 @@ def completion(
|
|||
proxy_server_request=proxy_server_request,
|
||||
preset_cache_key=preset_cache_key,
|
||||
no_log=no_log,
|
||||
input_cost_per_second=input_cost_per_second,
|
||||
input_cost_per_token=input_cost_per_token,
|
||||
output_cost_per_second=output_cost_per_second,
|
||||
output_cost_per_token=output_cost_per_token,
|
||||
)
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
@ -1210,6 +1215,61 @@ def completion(
|
|||
)
|
||||
|
||||
response = model_response
|
||||
elif (
|
||||
"clarifai" in model
|
||||
or custom_llm_provider == "clarifai"
|
||||
or model in litellm.clarifai_models
|
||||
):
|
||||
clarifai_key = None
|
||||
clarifai_key = (
|
||||
api_key
|
||||
or litellm.clarifai_key
|
||||
or litellm.api_key
|
||||
or get_secret("CLARIFAI_API_KEY")
|
||||
or get_secret("CLARIFAI_API_TOKEN")
|
||||
)
|
||||
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret("CLARIFAI_API_BASE")
|
||||
or "https://api.clarifai.com/v2"
|
||||
)
|
||||
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
model_response = clarifai.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
acompletion=acompletion,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding, # for calculating input/output tokens
|
||||
api_key=clarifai_key,
|
||||
logging_obj=logging,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=model_response,
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=clarifai_key,
|
||||
original_response=model_response,
|
||||
)
|
||||
response = model_response
|
||||
|
||||
elif custom_llm_provider == "anthropic":
|
||||
api_key = (
|
||||
|
@ -1919,6 +1979,24 @@ def completion(
|
|||
elif custom_llm_provider == "bedrock":
|
||||
# boto3 reads keys from .env
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
|
||||
if "cohere" in model:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
else:
|
||||
response = bedrock.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -2622,6 +2700,7 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "voyage"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "triton"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "openrouter"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
|
@ -2779,6 +2858,7 @@ def embedding(
|
|||
"no-log",
|
||||
"region_name",
|
||||
"allowed_model_region",
|
||||
"model_config",
|
||||
]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
|
@ -2955,6 +3035,23 @@ def embedding(
|
|||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
elif custom_llm_provider == "triton":
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base is required for triton. Please pass `api_base`"
|
||||
)
|
||||
response = triton_chat_completions.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
timeout=timeout,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = (
|
||||
optional_params.pop("vertex_project", None)
|
||||
|
@ -3662,6 +3759,7 @@ def image_generation(
|
|||
"cache",
|
||||
"region_name",
|
||||
"allowed_model_region",
|
||||
"model_config",
|
||||
]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
|
|
|
@ -9,6 +9,30 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"gpt-4o": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"gpt-4o-2024-05-13": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"gpt-4-turbo-preview": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
|
@ -1086,6 +1110,36 @@
|
|||
"supports_tool_choice": true,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini-1.5-flash-preview-0514": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_videos_per_prompt": 10,
|
||||
"max_video_length": 1,
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini-1.5-pro-preview-0514": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.000000625,
|
||||
"output_cost_per_token": 0.000001875,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini-1.5-pro-preview-0215": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
|
@ -1331,6 +1385,24 @@
|
|||
"mode": "completion",
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-latest": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_videos_per_prompt": 10,
|
||||
"max_video_length": 1,
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-pro": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 32760,
|
||||
|
@ -1571,6 +1643,159 @@
|
|||
"litellm_provider": "replicate",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/microsoft/wizardlm-2-8x22b:nitro": {
|
||||
"max_tokens": 65536,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000001,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/google/gemini-pro-1.5": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0000025,
|
||||
"output_cost_per_token": 0.0000075,
|
||||
"input_cost_per_image": 0.00265,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/mistralai/mixtral-8x22b-instruct": {
|
||||
"max_tokens": 65536,
|
||||
"input_cost_per_token": 0.00000065,
|
||||
"output_cost_per_token": 0.00000065,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/cohere/command-r-plus": {
|
||||
"max_tokens": 128000,
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/databricks/dbrx-instruct": {
|
||||
"max_tokens": 32768,
|
||||
"input_cost_per_token": 0.0000006,
|
||||
"output_cost_per_token": 0.0000006,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-haiku": {
|
||||
"max_tokens": 200000,
|
||||
"input_cost_per_token": 0.00000025,
|
||||
"output_cost_per_token": 0.00000125,
|
||||
"input_cost_per_image": 0.0004,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/anthropic/claude-3-sonnet": {
|
||||
"max_tokens": 200000,
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"input_cost_per_image": 0.0048,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/mistralai/mistral-large": {
|
||||
"max_tokens": 32000,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/cognitivecomputations/dolphin-mixtral-8x7b": {
|
||||
"max_tokens": 32769,
|
||||
"input_cost_per_token": 0.0000005,
|
||||
"output_cost_per_token": 0.0000005,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/google/gemini-pro-vision": {
|
||||
"max_tokens": 45875,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000375,
|
||||
"input_cost_per_image": 0.0025,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/fireworks/firellava-13b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000002,
|
||||
"output_cost_per_token": 0.0000002,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/meta-llama/llama-3-8b-instruct:free": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/meta-llama/llama-3-8b-instruct:extended": {
|
||||
"max_tokens": 16384,
|
||||
"input_cost_per_token": 0.000000225,
|
||||
"output_cost_per_token": 0.00000225,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/meta-llama/llama-3-70b-instruct:nitro": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.0000009,
|
||||
"output_cost_per_token": 0.0000009,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/meta-llama/llama-3-70b-instruct": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.00000059,
|
||||
"output_cost_per_token": 0.00000079,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/openai/gpt-4o": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-4o-2024-05-13": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-4-vision-preview": {
|
||||
"max_tokens": 130000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"input_cost_per_image": 0.01445,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-3.5-turbo": {
|
||||
"max_tokens": 4095,
|
||||
"input_cost_per_token": 0.0000015,
|
||||
|
@ -1621,14 +1846,14 @@
|
|||
"tool_use_system_prompt_tokens": 395
|
||||
},
|
||||
"openrouter/google/palm-2-chat-bison": {
|
||||
"max_tokens": 8000,
|
||||
"max_tokens": 25804,
|
||||
"input_cost_per_token": 0.0000005,
|
||||
"output_cost_per_token": 0.0000005,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/google/palm-2-codechat-bison": {
|
||||
"max_tokens": 8000,
|
||||
"max_tokens": 20070,
|
||||
"input_cost_per_token": 0.0000005,
|
||||
"output_cost_per_token": 0.0000005,
|
||||
"litellm_provider": "openrouter",
|
||||
|
@ -1711,13 +1936,6 @@
|
|||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/meta-llama/llama-3-70b-instruct": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.0000008,
|
||||
"output_cost_per_token": 0.0000008,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"j2-ultra": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
|
@ -2522,6 +2740,24 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"cohere.command-r-plus-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000030,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"cohere.command-r-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000005,
|
||||
"output_cost_per_token": 0.0000015,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"cohere.embed-english-v3": {
|
||||
"max_tokens": 512,
|
||||
"max_input_tokens": 512,
|
||||
|
@ -2749,6 +2985,24 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/llama3": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/llama3:70b": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mistral": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
|
@ -2758,6 +3012,42 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/mistral-7B-Instruct-v0.1": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mistral-7B-Instruct-v0.2": {
|
||||
"max_tokens": 32768,
|
||||
"max_input_tokens": 32768,
|
||||
"max_output_tokens": 32768,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mixtral-8x7B-Instruct-v0.1": {
|
||||
"max_tokens": 32768,
|
||||
"max_input_tokens": 32768,
|
||||
"max_output_tokens": 32768,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mixtral-8x22B-Instruct-v0.1": {
|
||||
"max_tokens": 65536,
|
||||
"max_input_tokens": 65536,
|
||||
"max_output_tokens": 65536,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/codellama": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1 +1 @@
|
|||
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/a1602eb39f799143.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();
|
||||
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/f04e46b02318b660.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1 +1 @@
|
|||
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-5b257e1ab47d4b4a.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-dafd44dfa2da140c.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e49705773ae41779.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-5b257e1ab47d4b4a.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/a1602eb39f799143.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[25539,[\"936\",\"static/chunks/2f6dbc85-17d29013b8ff3da5.js\",\"566\",\"static/chunks/566-ccd699ab19124658.js\",\"931\",\"static/chunks/app/page-c804e862b63be987.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/a1602eb39f799143.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"K8KXTbmuI2ArWjjdMi2iq\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
||||
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[4858,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-c35c14c9afd091ec.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"2ASoJGxS-D4w-vat00xMy\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
|
@ -1,7 +1,7 @@
|
|||
2:I[77831,[],""]
|
||||
3:I[25539,["936","static/chunks/2f6dbc85-17d29013b8ff3da5.js","566","static/chunks/566-ccd699ab19124658.js","931","static/chunks/app/page-c804e862b63be987.js"],""]
|
||||
3:I[4858,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-c35c14c9afd091ec.js"],""]
|
||||
4:I[5613,[],""]
|
||||
5:I[31778,[],""]
|
||||
0:["K8KXTbmuI2ArWjjdMi2iq",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/a1602eb39f799143.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
||||
0:["2ASoJGxS-D4w-vat00xMy",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
||||
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
|
||||
1:null
|
||||
|
|
|
@ -16,4 +16,3 @@ router_settings:
|
|||
|
||||
general_settings:
|
||||
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
|
||||
|
||||
|
|
|
@ -52,8 +52,18 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
|
|||
|
||||
|
||||
class LiteLLMRoutes(enum.Enum):
|
||||
openai_route_names: List = [
|
||||
"chat_completion",
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"audio_transcriptions",
|
||||
"moderations",
|
||||
"model_list", # OpenAI /v1/models route
|
||||
]
|
||||
openai_routes: List = [
|
||||
# chat completions
|
||||
"/engines/{model}/chat/completions",
|
||||
"/openai/deployments/{model}/chat/completions",
|
||||
"/chat/completions",
|
||||
"/v1/chat/completions",
|
||||
|
@ -79,15 +89,23 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/v1/models",
|
||||
]
|
||||
|
||||
llm_utils_routes: List = ["utils/token_counter"]
|
||||
|
||||
info_routes: List = [
|
||||
"/key/info",
|
||||
"/team/info",
|
||||
"/team/list",
|
||||
"/user/info",
|
||||
"/model/info",
|
||||
"/v2/model/info",
|
||||
"/v2/key/info",
|
||||
]
|
||||
|
||||
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
|
||||
master_key_only_routes: List = [
|
||||
"/global/spend/reset",
|
||||
]
|
||||
|
||||
sso_only_routes: List = [
|
||||
"/key/generate",
|
||||
"/key/update",
|
||||
|
@ -110,6 +128,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/team/new",
|
||||
"/team/update",
|
||||
"/team/delete",
|
||||
"/team/list",
|
||||
"/team/info",
|
||||
"/team/block",
|
||||
"/team/unblock",
|
||||
|
@ -182,15 +201,32 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
|||
|
||||
admin_jwt_scope: str = "litellm_proxy_admin"
|
||||
admin_allowed_routes: List[
|
||||
Literal["openai_routes", "info_routes", "management_routes"]
|
||||
] = ["management_routes"]
|
||||
team_jwt_scope: str = "litellm_team"
|
||||
team_id_jwt_field: str = "client_id"
|
||||
Literal[
|
||||
"openai_routes",
|
||||
"info_routes",
|
||||
"management_routes",
|
||||
"spend_tracking_routes",
|
||||
"global_spend_tracking_routes",
|
||||
]
|
||||
] = [
|
||||
"management_routes",
|
||||
"spend_tracking_routes",
|
||||
"global_spend_tracking_routes",
|
||||
"info_routes",
|
||||
]
|
||||
team_id_jwt_field: Optional[str] = None
|
||||
team_allowed_routes: List[
|
||||
Literal["openai_routes", "info_routes", "management_routes"]
|
||||
] = ["openai_routes", "info_routes"]
|
||||
team_id_default: Optional[str] = Field(
|
||||
default=None,
|
||||
description="If no team_id given, default permissions/spend-tracking to this team.s",
|
||||
)
|
||||
org_id_jwt_field: Optional[str] = None
|
||||
user_id_jwt_field: Optional[str] = None
|
||||
user_id_upsert: bool = Field(
|
||||
default=False, description="If user doesn't exist, upsert them into the db."
|
||||
)
|
||||
end_user_id_jwt_field: Optional[str] = None
|
||||
public_key_ttl: float = 600
|
||||
|
||||
|
@ -678,6 +714,25 @@ class DynamoDBArgs(LiteLLMBase):
|
|||
assume_role_aws_session_name: Optional[str] = None
|
||||
|
||||
|
||||
class ConfigFieldUpdate(LiteLLMBase):
|
||||
field_name: str
|
||||
field_value: Any
|
||||
config_type: Literal["general_settings"]
|
||||
|
||||
|
||||
class ConfigFieldDelete(LiteLLMBase):
|
||||
config_type: Literal["general_settings"]
|
||||
field_name: str
|
||||
|
||||
|
||||
class ConfigList(LiteLLMBase):
|
||||
field_name: str
|
||||
field_type: str
|
||||
field_description: str
|
||||
field_value: Any
|
||||
stored_in_db: Optional[bool]
|
||||
|
||||
|
||||
class ConfigGeneralSettings(LiteLLMBase):
|
||||
"""
|
||||
Documents all the fields supported by `general_settings` in config.yaml
|
||||
|
@ -725,7 +780,11 @@ class ConfigGeneralSettings(LiteLLMBase):
|
|||
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
|
||||
)
|
||||
max_parallel_requests: Optional[int] = Field(
|
||||
None, description="maximum parallel requests for each api key"
|
||||
None,
|
||||
description="maximum parallel requests for each api key",
|
||||
)
|
||||
global_max_parallel_requests: Optional[int] = Field(
|
||||
None, description="global max parallel requests to allow for a proxy instance."
|
||||
)
|
||||
infer_model_from_keys: Optional[bool] = Field(
|
||||
None,
|
||||
|
@ -954,3 +1013,16 @@ class LiteLLM_ErrorLogs(LiteLLMBase):
|
|||
|
||||
class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase):
|
||||
response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None
|
||||
|
||||
|
||||
class TokenCountRequest(LiteLLMBase):
|
||||
model: str
|
||||
prompt: Optional[str] = None
|
||||
messages: Optional[List[dict]] = None
|
||||
|
||||
|
||||
class TokenCountResponse(LiteLLMBase):
|
||||
total_tokens: int
|
||||
request_model: str
|
||||
model_used: str
|
||||
tokenizer_type: str
|
||||
|
|
|
@ -26,7 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
|
|||
|
||||
def common_checks(
|
||||
request_body: dict,
|
||||
team_object: LiteLLM_TeamTable,
|
||||
team_object: Optional[LiteLLM_TeamTable],
|
||||
user_object: Optional[LiteLLM_UserTable],
|
||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||
global_proxy_spend: Optional[float],
|
||||
|
@ -45,13 +45,14 @@ def common_checks(
|
|||
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||
"""
|
||||
_model = request_body.get("model", None)
|
||||
if team_object.blocked == True:
|
||||
if team_object is not None and team_object.blocked == True:
|
||||
raise Exception(
|
||||
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
||||
)
|
||||
# 2. If user can call model
|
||||
if (
|
||||
_model is not None
|
||||
and team_object is not None
|
||||
and len(team_object.models) > 0
|
||||
and _model not in team_object.models
|
||||
):
|
||||
|
@ -65,7 +66,8 @@ def common_checks(
|
|||
)
|
||||
# 3. If team is in budget
|
||||
if (
|
||||
team_object.max_budget is not None
|
||||
team_object is not None
|
||||
and team_object.max_budget is not None
|
||||
and team_object.spend is not None
|
||||
and team_object.spend > team_object.max_budget
|
||||
):
|
||||
|
@ -239,6 +241,7 @@ async def get_user_object(
|
|||
user_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
user_id_upsert: bool,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
|
@ -252,7 +255,7 @@ async def get_user_object(
|
|||
return None
|
||||
|
||||
# check if in cache
|
||||
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
|
||||
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if cached_user_obj is not None:
|
||||
if isinstance(cached_user_obj, dict):
|
||||
return LiteLLM_UserTable(**cached_user_obj)
|
||||
|
@ -260,16 +263,27 @@ async def get_user_object(
|
|||
return cached_user_obj
|
||||
# else, check db
|
||||
try:
|
||||
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
if response is None:
|
||||
if user_id_upsert:
|
||||
response = await prisma_client.db.litellm_usertable.create(
|
||||
data={"user_id": user_id}
|
||||
)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_UserTable(**response.dict())
|
||||
except Exception as e: # if end-user not in db
|
||||
raise Exception(
|
||||
_response = LiteLLM_UserTable(**dict(response))
|
||||
|
||||
# save the user object to cache
|
||||
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
|
||||
|
||||
return _response
|
||||
except Exception as e: # if user not in db
|
||||
raise ValueError(
|
||||
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
||||
)
|
||||
|
||||
|
@ -290,7 +304,7 @@ async def get_team_object(
|
|||
)
|
||||
|
||||
# check if in cache
|
||||
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
|
||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
|
||||
if cached_team_obj is not None:
|
||||
if isinstance(cached_team_obj, dict):
|
||||
return LiteLLM_TeamTable(**cached_team_obj)
|
||||
|
@ -305,7 +319,11 @@ async def get_team_object(
|
|||
if response is None:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_TeamTable(**response.dict())
|
||||
_response = LiteLLM_TeamTable(**response.dict())
|
||||
# save the team object to cache
|
||||
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
|
||||
|
||||
return _response
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||
|
|
|
@ -55,12 +55,9 @@ class JWTHandler:
|
|||
return True
|
||||
return False
|
||||
|
||||
def is_team(self, scopes: list) -> bool:
|
||||
if self.litellm_jwtauth.team_jwt_scope in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
||||
def get_end_user_id(
|
||||
self, token: dict, default_value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||
|
@ -70,13 +67,36 @@ class JWTHandler:
|
|||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def is_required_team_id(self) -> bool:
|
||||
"""
|
||||
Returns:
|
||||
- True: if 'team_id_jwt_field' is set
|
||||
- False: if not
|
||||
"""
|
||||
if self.litellm_jwtauth.team_id_jwt_field is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.team_id_jwt_field is not None:
|
||||
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||
elif self.litellm_jwtauth.team_id_default is not None:
|
||||
team_id = self.litellm_jwtauth.team_id_default
|
||||
else:
|
||||
team_id = None
|
||||
except KeyError:
|
||||
team_id = default_value
|
||||
return team_id
|
||||
|
||||
def is_upsert_user_id(self) -> bool:
|
||||
"""
|
||||
Returns:
|
||||
- True: if 'user_id_upsert' is set
|
||||
- False: if not
|
||||
"""
|
||||
return self.litellm_jwtauth.user_id_upsert
|
||||
|
||||
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
||||
|
@ -207,12 +227,14 @@ class JWTHandler:
|
|||
raise Exception(f"Validation fails: {str(e)}")
|
||||
elif public_key is not None and isinstance(public_key, str):
|
||||
try:
|
||||
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
|
||||
cert = x509.load_pem_x509_certificate(
|
||||
public_key.encode(), default_backend()
|
||||
)
|
||||
|
||||
# Extract public key
|
||||
key = cert.public_key().public_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
# decode the token using the public key
|
||||
|
@ -221,7 +243,7 @@ class JWTHandler:
|
|||
key,
|
||||
algorithms=algorithms,
|
||||
audience=audience,
|
||||
options=decode_options
|
||||
options=decode_options,
|
||||
)
|
||||
return payload
|
||||
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
|
|
147
litellm/proxy/hooks/azure_content_safety.py
Normal file
147
litellm/proxy/hooks/azure_content_safety.py
Normal file
|
@ -0,0 +1,147 @@
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
import litellm, traceback, sys, uuid
|
||||
from fastapi import HTTPException
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class _PROXY_AzureContentSafety(
|
||||
CustomLogger
|
||||
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
# Class variables or attributes
|
||||
|
||||
def __init__(self, endpoint, api_key, thresholds=None):
|
||||
try:
|
||||
from azure.ai.contentsafety.aio import ContentSafetyClient
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.ai.contentsafety.models import (
|
||||
TextCategory,
|
||||
AnalyzeTextOptions,
|
||||
AnalyzeTextOutputType,
|
||||
)
|
||||
from azure.core.exceptions import HttpResponseError
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
|
||||
)
|
||||
self.endpoint = endpoint
|
||||
self.api_key = api_key
|
||||
self.text_category = TextCategory
|
||||
self.analyze_text_options = AnalyzeTextOptions
|
||||
self.analyze_text_output_type = AnalyzeTextOutputType
|
||||
self.azure_http_error = HttpResponseError
|
||||
|
||||
self.thresholds = self._configure_thresholds(thresholds)
|
||||
|
||||
self.client = ContentSafetyClient(
|
||||
self.endpoint, AzureKeyCredential(self.api_key)
|
||||
)
|
||||
|
||||
def _configure_thresholds(self, thresholds=None):
|
||||
default_thresholds = {
|
||||
self.text_category.HATE: 4,
|
||||
self.text_category.SELF_HARM: 4,
|
||||
self.text_category.SEXUAL: 4,
|
||||
self.text_category.VIOLENCE: 4,
|
||||
}
|
||||
|
||||
if thresholds is None:
|
||||
return default_thresholds
|
||||
|
||||
for key, default in default_thresholds.items():
|
||||
if key not in thresholds:
|
||||
thresholds[key] = default
|
||||
|
||||
return thresholds
|
||||
|
||||
def _compute_result(self, response):
|
||||
result = {}
|
||||
|
||||
category_severity = {
|
||||
item.category: item.severity for item in response.categories_analysis
|
||||
}
|
||||
for category in self.text_category:
|
||||
severity = category_severity.get(category)
|
||||
if severity is not None:
|
||||
result[category] = {
|
||||
"filtered": severity >= self.thresholds[category],
|
||||
"severity": severity,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
async def test_violation(self, content: str, source: Optional[str] = None):
|
||||
verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
|
||||
|
||||
# Construct a request
|
||||
request = self.analyze_text_options(
|
||||
text=content,
|
||||
output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
|
||||
)
|
||||
|
||||
# Analyze text
|
||||
try:
|
||||
response = await self.client.analyze_text(request)
|
||||
except self.azure_http_error as e:
|
||||
verbose_proxy_logger.debug(
|
||||
"Error in Azure Content-Safety: %s", traceback.format_exc()
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
result = self._compute_result(response)
|
||||
verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
|
||||
|
||||
for key, value in result.items():
|
||||
if value["filtered"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated content safety policy",
|
||||
"source": source,
|
||||
"category": key,
|
||||
"severity": value["severity"],
|
||||
},
|
||||
)
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
|
||||
):
|
||||
verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
|
||||
try:
|
||||
if call_type == "completion" and "messages" in data:
|
||||
for m in data["messages"]:
|
||||
if "content" in m and isinstance(m["content"], str):
|
||||
await self.test_violation(content=m["content"], source="input")
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
|
||||
if isinstance(response, litellm.ModelResponse) and isinstance(
|
||||
response.choices[0], litellm.utils.Choices
|
||||
):
|
||||
await self.test_violation(
|
||||
content=response.choices[0].message.content, source="output"
|
||||
)
|
||||
|
||||
# async def async_post_call_streaming_hook(
|
||||
# self,
|
||||
# user_api_key_dict: UserAPIKeyAuth,
|
||||
# response: str,
|
||||
# ):
|
||||
# verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
|
||||
# await self.test_violation(content=response, source="output")
|
|
@ -79,6 +79,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
||||
if max_parallel_requests is None:
|
||||
max_parallel_requests = sys.maxsize
|
||||
global_max_parallel_requests = data.get("metadata", {}).get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||
if tpm_limit is None:
|
||||
tpm_limit = sys.maxsize
|
||||
|
@ -91,6 +94,24 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
current_global_requests = await cache.async_get_cache(
|
||||
key=_key, local_only=True
|
||||
)
|
||||
# check if below limit
|
||||
if current_global_requests is None:
|
||||
current_global_requests = 1
|
||||
# if above -> raise error
|
||||
if current_global_requests >= global_max_parallel_requests:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Max parallel request limit reached."
|
||||
)
|
||||
# if below -> increment
|
||||
else:
|
||||
await cache.async_increment_cache(key=_key, value=1, local_only=True)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
|
@ -207,6 +228,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
||||
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_user_id", None
|
||||
|
@ -222,6 +246,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
# decrement
|
||||
await self.user_api_key_cache.async_increment_cache(
|
||||
key=_key, value=-1, local_only=True
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
|
@ -336,6 +368,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
||||
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
user_api_key = (
|
||||
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
|
||||
)
|
||||
|
@ -347,17 +382,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
return
|
||||
|
||||
## decrement call count if call failed
|
||||
if (
|
||||
hasattr(kwargs["exception"], "status_code")
|
||||
and kwargs["exception"].status_code == 429
|
||||
and "Max parallel request limit reached" in str(kwargs["exception"])
|
||||
):
|
||||
if "Max parallel request limit reached" in str(kwargs["exception"]):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
current_global_requests = (
|
||||
await self.user_api_key_cache.async_get_cache(
|
||||
key=_key, local_only=True
|
||||
)
|
||||
)
|
||||
# decrement
|
||||
await self.user_api_key_cache.async_increment_cache(
|
||||
key=_key, value=-1, local_only=True
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
|
|
|
@ -11,7 +11,9 @@ sys.path.append(os.getcwd())
|
|||
|
||||
config_filename = "litellm.secrets"
|
||||
|
||||
load_dotenv()
|
||||
litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
|
||||
if litellm_mode == "DEV":
|
||||
load_dotenv()
|
||||
from importlib import resources
|
||||
import shutil
|
||||
|
||||
|
|
|
@ -4,11 +4,20 @@ model_list:
|
|||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
- model_name: llama3
|
||||
litellm_params:
|
||||
model: groq/llama3-8b-8192
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
- model_name: "*"
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
- model_name: my-triton-model
|
||||
litellm_params:
|
||||
model: triton/any"
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
|
||||
|
||||
general_settings:
|
||||
store_model_in_db: true
|
||||
|
@ -17,4 +26,10 @@ general_settings:
|
|||
|
||||
litellm_settings:
|
||||
success_callback: ["langfuse"]
|
||||
_langfuse_default_tags: ["user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]
|
||||
failure_callback: ["langfuse"]
|
||||
default_team_settings:
|
||||
- team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
|
||||
success_callback: ["langfuse"]
|
||||
failure_callback: ["langfuse"]
|
||||
langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
|
||||
langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -140,6 +140,8 @@ class ProxyLogging:
|
|||
self.slack_alerting_instance.response_taking_too_long_callback
|
||||
)
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, str):
|
||||
callback = litellm.utils._init_custom_logger_compatible_class(callback)
|
||||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback)
|
||||
if callback not in litellm.success_callback:
|
||||
|
@ -252,8 +254,8 @@ class ProxyLogging:
|
|||
"""
|
||||
Runs the CustomLogger's async_moderation_hook()
|
||||
"""
|
||||
for callback in litellm.callbacks:
|
||||
new_data = copy.deepcopy(data)
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_moderation_hook(
|
||||
|
@ -418,9 +420,14 @@ class ProxyLogging:
|
|||
|
||||
Related issue - https://github.com/BerriAI/litellm/issues/3395
|
||||
"""
|
||||
litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
|
||||
exception_str = str(original_exception)
|
||||
if litellm_debug_info is not None:
|
||||
exception_str += litellm_debug_info
|
||||
|
||||
asyncio.create_task(
|
||||
self.alerting_handler(
|
||||
message=f"LLM API call failed: {str(original_exception)}",
|
||||
message=f"LLM API call failed: {exception_str}",
|
||||
level="High",
|
||||
alert_type="llm_exceptions",
|
||||
request_data=request_data,
|
||||
|
@ -1787,7 +1794,9 @@ def hash_token(token: str):
|
|||
return hashed_token
|
||||
|
||||
|
||||
def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
||||
def get_logging_payload(
|
||||
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
|
||||
):
|
||||
from litellm.proxy._types import LiteLLM_SpendLogs
|
||||
from pydantic import Json
|
||||
import uuid
|
||||
|
@ -1865,7 +1874,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
|||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"request_tags": metadata.get("tags", []),
|
||||
"end_user": kwargs.get("user", ""),
|
||||
"end_user": end_user_id or "",
|
||||
"api_base": litellm_params.get("api_base", ""),
|
||||
}
|
||||
|
||||
|
@ -2028,6 +2037,11 @@ async def update_spend(
|
|||
raise e
|
||||
|
||||
### UPDATE END-USER TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"End-User Spend transactions: {}".format(
|
||||
len(prisma_client.end_user_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.end_user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
|
@ -2043,13 +2057,18 @@ async def update_spend(
|
|||
max_end_user_budget = None
|
||||
if litellm.max_end_user_budget is not None:
|
||||
max_end_user_budget = litellm.max_end_user_budget
|
||||
new_user_obj = LiteLLM_EndUserTable(
|
||||
user_id=end_user_id, spend=response_cost, blocked=False
|
||||
)
|
||||
batcher.litellm_endusertable.update_many(
|
||||
batcher.litellm_endusertable.upsert(
|
||||
where={"user_id": end_user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
data={
|
||||
"create": {
|
||||
"user_id": end_user_id,
|
||||
"spend": response_cost,
|
||||
"blocked": False,
|
||||
},
|
||||
"update": {"spend": {"increment": response_cost}},
|
||||
},
|
||||
)
|
||||
|
||||
prisma_client.end_user_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
|
|
|
@ -9,7 +9,8 @@
|
|||
|
||||
import copy, httpx
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
|
||||
from typing_extensions import overload
|
||||
import random, threading, time, traceback, uuid
|
||||
import litellm, openai, hashlib, json
|
||||
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
||||
|
@ -46,8 +47,10 @@ from litellm.types.router import (
|
|||
updateLiteLLMParams,
|
||||
RetryPolicy,
|
||||
AlertingConfig,
|
||||
DeploymentTypedDict,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
|
||||
|
||||
class Router:
|
||||
|
@ -60,7 +63,7 @@ class Router:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_list: Optional[list] = None,
|
||||
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
|
||||
## CACHING ##
|
||||
redis_url: Optional[str] = None,
|
||||
redis_host: Optional[str] = None,
|
||||
|
@ -81,6 +84,9 @@ class Router:
|
|||
default_max_parallel_requests: Optional[int] = None,
|
||||
set_verbose: bool = False,
|
||||
debug_level: Literal["DEBUG", "INFO"] = "INFO",
|
||||
default_fallbacks: Optional[
|
||||
List[str]
|
||||
] = None, # generic fallbacks, works across all deployments
|
||||
fallbacks: List = [],
|
||||
context_window_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
|
@ -256,7 +262,22 @@ class Router:
|
|||
|
||||
self.retry_after = retry_after
|
||||
self.routing_strategy = routing_strategy
|
||||
self.fallbacks = fallbacks or litellm.fallbacks
|
||||
|
||||
## SETTING FALLBACKS ##
|
||||
### validate if it's set + in correct format
|
||||
_fallbacks = fallbacks or litellm.fallbacks
|
||||
|
||||
self.validate_fallbacks(fallback_param=_fallbacks)
|
||||
### set fallbacks
|
||||
self.fallbacks = _fallbacks
|
||||
|
||||
if default_fallbacks is not None or litellm.default_fallbacks is not None:
|
||||
_fallbacks = default_fallbacks or litellm.default_fallbacks
|
||||
if self.fallbacks is not None:
|
||||
self.fallbacks.append({"*": _fallbacks})
|
||||
else:
|
||||
self.fallbacks = [{"*": _fallbacks}]
|
||||
|
||||
self.context_window_fallbacks = (
|
||||
context_window_fallbacks or litellm.context_window_fallbacks
|
||||
)
|
||||
|
@ -324,6 +345,21 @@ class Router:
|
|||
if self.alerting_config is not None:
|
||||
self._initialize_alerting()
|
||||
|
||||
def validate_fallbacks(self, fallback_param: Optional[List]):
|
||||
if fallback_param is None:
|
||||
return
|
||||
if len(fallback_param) > 0: # if set
|
||||
## for dictionary in list, check if only 1 key in dict
|
||||
for _dict in fallback_param:
|
||||
assert isinstance(_dict, dict), "Item={}, not a dictionary".format(
|
||||
_dict
|
||||
)
|
||||
assert (
|
||||
len(_dict.keys()) == 1
|
||||
), "Only 1 key allows in dictionary. You set={} for dict={}".format(
|
||||
len(_dict.keys()), _dict
|
||||
)
|
||||
|
||||
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
|
||||
if routing_strategy == "least-busy":
|
||||
self.leastbusy_logger = LeastBusyLoggingHandler(
|
||||
|
@ -468,12 +504,30 @@ class Router:
|
|||
)
|
||||
raise e
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
async def acompletion(
|
||||
self, model: str, messages: List[Dict[str, str]], **kwargs
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
|
||||
) -> CustomStreamWrapper:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def acompletion(
|
||||
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
|
||||
) -> ModelResponse:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
# The actual implementation of the function
|
||||
async def acompletion(
|
||||
self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
|
||||
):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["stream"] = stream
|
||||
kwargs["original_function"] = self._acompletion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
|
||||
|
@ -605,6 +659,33 @@ class Router:
|
|||
self.fail_calls[model_name] += 1
|
||||
raise e
|
||||
|
||||
async def abatch_completion(
|
||||
self, models: List[str], messages: List[Dict[str, str]], **kwargs
|
||||
):
|
||||
|
||||
async def _async_completion_no_exceptions(
|
||||
model: str, messages: List[Dict[str, str]], **kwargs
|
||||
):
|
||||
"""
|
||||
Wrapper around self.async_completion that catches exceptions and returns them as a result
|
||||
"""
|
||||
try:
|
||||
return await self.acompletion(model=model, messages=messages, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
_tasks = []
|
||||
for model in models:
|
||||
# add each task but if the task fails
|
||||
_tasks.append(
|
||||
_async_completion_no_exceptions(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
response = await asyncio.gather(*_tasks)
|
||||
return response
|
||||
|
||||
def image_generation(self, prompt: str, model: str, **kwargs):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
|
@ -1385,7 +1466,7 @@ class Router:
|
|||
verbose_router_logger.debug(f"Trying to fallback b/w models")
|
||||
if (
|
||||
hasattr(e, "status_code")
|
||||
and e.status_code == 400
|
||||
and e.status_code == 400 # type: ignore
|
||||
and not isinstance(e, litellm.ContextWindowExceededError)
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
|
@ -1416,18 +1497,29 @@ class Router:
|
|||
response = await self.async_function_with_retries(
|
||||
*args, **kwargs
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
"Successful fallback b/w models."
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
pass
|
||||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
for item in fallbacks:
|
||||
key_list = list(item.keys())
|
||||
if len(key_list) == 0:
|
||||
continue
|
||||
if key_list[0] == model_group:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
verbose_router_logger.info(
|
||||
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
|
||||
|
@ -1450,6 +1542,9 @@ class Router:
|
|||
response = await self.async_function_with_fallbacks(
|
||||
*args, **kwargs
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
"Successful fallback b/w models."
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -1479,22 +1574,30 @@ class Router:
|
|||
return response
|
||||
except Exception as e:
|
||||
original_exception = e
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
|
||||
if (
|
||||
isinstance(original_exception, litellm.ContextWindowExceededError)
|
||||
and context_window_fallbacks is not None
|
||||
) or (
|
||||
isinstance(original_exception, openai.RateLimitError)
|
||||
and fallbacks is not None
|
||||
):
|
||||
raise original_exception
|
||||
### RETRY
|
||||
"""
|
||||
Retry Logic
|
||||
|
||||
_timeout = self._router_should_retry(
|
||||
"""
|
||||
_healthy_deployments = await self._async_get_healthy_deployments(
|
||||
model=kwargs.get("model") or "",
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retries
|
||||
self.should_retry_this_error(
|
||||
error=e,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
context_window_fallbacks=context_window_fallbacks,
|
||||
)
|
||||
|
||||
# decides how long to sleep before retry
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=original_exception,
|
||||
remaining_retries=num_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
)
|
||||
|
||||
# sleeps for the length of the timeout
|
||||
await asyncio.sleep(_timeout)
|
||||
|
||||
if (
|
||||
|
@ -1528,10 +1631,14 @@ class Router:
|
|||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
_timeout = self._router_should_retry(
|
||||
_healthy_deployments = await self._async_get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
)
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=original_exception,
|
||||
remaining_retries=remaining_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
)
|
||||
await asyncio.sleep(_timeout)
|
||||
try:
|
||||
|
@ -1540,17 +1647,57 @@ class Router:
|
|||
pass
|
||||
raise original_exception
|
||||
|
||||
def should_retry_this_error(
|
||||
self,
|
||||
error: Exception,
|
||||
healthy_deployments: Optional[List] = None,
|
||||
context_window_fallbacks: Optional[List] = None,
|
||||
):
|
||||
"""
|
||||
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
|
||||
|
||||
2. raise an exception for RateLimitError if
|
||||
- there are no fallbacks
|
||||
- there are no healthy deployments in the same model group
|
||||
"""
|
||||
|
||||
_num_healthy_deployments = 0
|
||||
if healthy_deployments is not None and isinstance(healthy_deployments, list):
|
||||
_num_healthy_deployments = len(healthy_deployments)
|
||||
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
|
||||
if (
|
||||
isinstance(error, litellm.ContextWindowExceededError)
|
||||
and context_window_fallbacks is None
|
||||
):
|
||||
raise error
|
||||
|
||||
# Error we should only retry if there are other deployments
|
||||
if isinstance(error, openai.RateLimitError) or isinstance(
|
||||
error, openai.AuthenticationError
|
||||
):
|
||||
if _num_healthy_deployments <= 0:
|
||||
raise error
|
||||
|
||||
return True
|
||||
|
||||
def function_with_fallbacks(self, *args, **kwargs):
|
||||
"""
|
||||
Try calling the function_with_retries
|
||||
If it fails after num_retries, fall back to another model group
|
||||
"""
|
||||
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
|
||||
model_group = kwargs.get("model")
|
||||
fallbacks = kwargs.get("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.get(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
try:
|
||||
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
|
||||
raise Exception(
|
||||
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
|
||||
)
|
||||
|
||||
response = self.function_with_retries(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -1559,7 +1706,7 @@ class Router:
|
|||
try:
|
||||
if (
|
||||
hasattr(e, "status_code")
|
||||
and e.status_code == 400
|
||||
and e.status_code == 400 # type: ignore
|
||||
and not isinstance(e, litellm.ContextWindowExceededError)
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
|
@ -1601,10 +1748,20 @@ class Router:
|
|||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
fallback_model_group = None
|
||||
for item in fallbacks:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
@ -1628,12 +1785,27 @@ class Router:
|
|||
raise e
|
||||
raise original_exception
|
||||
|
||||
def _router_should_retry(
|
||||
self, e: Exception, remaining_retries: int, num_retries: int
|
||||
def _time_to_sleep_before_retry(
|
||||
self,
|
||||
e: Exception,
|
||||
remaining_retries: int,
|
||||
num_retries: int,
|
||||
healthy_deployments: Optional[List] = None,
|
||||
) -> Union[int, float]:
|
||||
"""
|
||||
Calculate back-off, then retry
|
||||
|
||||
It should instantly retry only when:
|
||||
1. there are healthy deployments in the same model group
|
||||
2. there are fallbacks for the completion call
|
||||
"""
|
||||
if (
|
||||
healthy_deployments is not None
|
||||
and isinstance(healthy_deployments, list)
|
||||
and len(healthy_deployments) > 0
|
||||
):
|
||||
return 0
|
||||
|
||||
if hasattr(e, "response") and hasattr(e.response, "headers"):
|
||||
timeout = litellm._calculate_retry_after(
|
||||
remaining_retries=remaining_retries,
|
||||
|
@ -1670,23 +1842,29 @@ class Router:
|
|||
except Exception as e:
|
||||
original_exception = e
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
|
||||
if (
|
||||
isinstance(original_exception, litellm.ContextWindowExceededError)
|
||||
and context_window_fallbacks is not None
|
||||
) or (
|
||||
isinstance(original_exception, openai.RateLimitError)
|
||||
and fallbacks is not None
|
||||
):
|
||||
raise original_exception
|
||||
## LOGGING
|
||||
if num_retries > 0:
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
### RETRY
|
||||
_timeout = self._router_should_retry(
|
||||
_healthy_deployments = self._get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retries
|
||||
self.should_retry_this_error(
|
||||
error=e,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
context_window_fallbacks=context_window_fallbacks,
|
||||
)
|
||||
|
||||
# decides how long to sleep before retry
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=original_exception,
|
||||
remaining_retries=num_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
if num_retries > 0:
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
|
||||
time.sleep(_timeout)
|
||||
for current_attempt in range(num_retries):
|
||||
verbose_router_logger.debug(
|
||||
|
@ -1700,11 +1878,15 @@ class Router:
|
|||
except Exception as e:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
_healthy_deployments = self._get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
_timeout = self._router_should_retry(
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=e,
|
||||
remaining_retries=remaining_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
)
|
||||
time.sleep(_timeout)
|
||||
raise original_exception
|
||||
|
@ -1804,6 +1986,45 @@ class Router:
|
|||
key=rpm_key, value=request_count, local_only=True
|
||||
) # don't change existing ttl
|
||||
|
||||
def _is_cooldown_required(self, exception_status: Union[str, int]):
|
||||
"""
|
||||
A function to determine if a cooldown is required based on the exception status.
|
||||
|
||||
Parameters:
|
||||
exception_status (Union[str, int]): The status of the exception.
|
||||
|
||||
Returns:
|
||||
bool: True if a cooldown is required, False otherwise.
|
||||
"""
|
||||
try:
|
||||
|
||||
if isinstance(exception_status, str):
|
||||
exception_status = int(exception_status)
|
||||
|
||||
if exception_status >= 400 and exception_status < 500:
|
||||
if exception_status == 429:
|
||||
# Cool down 429 Rate Limit Errors
|
||||
return True
|
||||
|
||||
elif exception_status == 401:
|
||||
# Cool down 401 Auth Errors
|
||||
return True
|
||||
|
||||
elif exception_status == 408:
|
||||
return True
|
||||
|
||||
else:
|
||||
# Do NOT cool down all other 4XX Errors
|
||||
return False
|
||||
|
||||
else:
|
||||
# should cool down for all other errors
|
||||
return True
|
||||
|
||||
except:
|
||||
# Catch all - if any exceptions default to cooling down
|
||||
return True
|
||||
|
||||
def _set_cooldown_deployments(
|
||||
self, exception_status: Union[str, int], deployment: Optional[str] = None
|
||||
):
|
||||
|
@ -1817,6 +2038,9 @@ class Router:
|
|||
if deployment is None:
|
||||
return
|
||||
|
||||
if self._is_cooldown_required(exception_status=exception_status) == False:
|
||||
return
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
# get current fails for deployment
|
||||
|
@ -1907,6 +2131,47 @@ class Router:
|
|||
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
|
||||
return cooldown_models
|
||||
|
||||
def _get_healthy_deployments(self, model: str):
|
||||
_all_deployments: list = []
|
||||
try:
|
||||
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
|
||||
model=model,
|
||||
)
|
||||
if type(_all_deployments) == dict:
|
||||
return []
|
||||
except:
|
||||
pass
|
||||
|
||||
unhealthy_deployments = self._get_cooldown_deployments()
|
||||
healthy_deployments: list = []
|
||||
for deployment in _all_deployments:
|
||||
if deployment["model_info"]["id"] in unhealthy_deployments:
|
||||
continue
|
||||
else:
|
||||
healthy_deployments.append(deployment)
|
||||
|
||||
return healthy_deployments
|
||||
|
||||
async def _async_get_healthy_deployments(self, model: str):
|
||||
_all_deployments: list = []
|
||||
try:
|
||||
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
|
||||
model=model,
|
||||
)
|
||||
if type(_all_deployments) == dict:
|
||||
return []
|
||||
except:
|
||||
pass
|
||||
|
||||
unhealthy_deployments = await self._async_get_cooldown_deployments()
|
||||
healthy_deployments: list = []
|
||||
for deployment in _all_deployments:
|
||||
if deployment["model_info"]["id"] in unhealthy_deployments:
|
||||
continue
|
||||
else:
|
||||
healthy_deployments.append(deployment)
|
||||
return healthy_deployments
|
||||
|
||||
def routing_strategy_pre_call_checks(self, deployment: dict):
|
||||
"""
|
||||
Mimics 'async_routing_strategy_pre_call_checks'
|
||||
|
@ -2115,6 +2380,10 @@ class Router:
|
|||
raise ValueError(
|
||||
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
|
||||
)
|
||||
azure_ad_token = litellm_params.get("azure_ad_token")
|
||||
if azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
if api_version is None:
|
||||
api_version = "2023-07-01-preview"
|
||||
|
||||
|
@ -2126,6 +2395,7 @@ class Router:
|
|||
cache_key = f"{model_id}_async_client"
|
||||
_client = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
|
@ -2150,6 +2420,7 @@ class Router:
|
|||
cache_key = f"{model_id}_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
|
@ -2174,6 +2445,7 @@ class Router:
|
|||
cache_key = f"{model_id}_stream_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
|
@ -2198,6 +2470,7 @@ class Router:
|
|||
cache_key = f"{model_id}_stream_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
|
@ -2230,6 +2503,7 @@ class Router:
|
|||
"api_key": api_key,
|
||||
"azure_endpoint": api_base,
|
||||
"api_version": api_version,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
}
|
||||
from litellm.llms.azure import select_azure_base_url_or_endpoint
|
||||
|
||||
|
@ -2329,7 +2603,7 @@ class Router:
|
|||
) # cache for 1 hr
|
||||
|
||||
else:
|
||||
_api_key = api_key
|
||||
_api_key = api_key # type: ignore
|
||||
if _api_key is not None and isinstance(_api_key, str):
|
||||
# only show first 5 chars of api_key
|
||||
_api_key = _api_key[:8] + "*" * 15
|
||||
|
@ -2557,16 +2831,25 @@ class Router:
|
|||
# init OpenAI, Azure clients
|
||||
self.set_client(model=deployment.to_json(exclude_none=True))
|
||||
|
||||
# set region (if azure model)
|
||||
# set region (if azure model) ## PREVIEW FEATURE ##
|
||||
if litellm.enable_preview_features == True:
|
||||
print("Auto inferring region") # noqa
|
||||
"""
|
||||
Hiding behind a feature flag
|
||||
When there is a large amount of LLM deployments this makes startup times blow up
|
||||
"""
|
||||
try:
|
||||
if "azure" in deployment.litellm_params.model:
|
||||
if (
|
||||
"azure" in deployment.litellm_params.model
|
||||
and deployment.litellm_params.region_name is None
|
||||
):
|
||||
region = litellm.utils.get_model_region(
|
||||
litellm_params=deployment.litellm_params, mode=None
|
||||
)
|
||||
|
||||
deployment.litellm_params.region_name = region
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(
|
||||
verbose_router_logger.debug(
|
||||
"Unable to get the region for azure model - {}, {}".format(
|
||||
deployment.litellm_params.model, str(e)
|
||||
)
|
||||
|
@ -2600,7 +2883,7 @@ class Router:
|
|||
self.model_names.append(deployment.model_name)
|
||||
return deployment
|
||||
|
||||
def upsert_deployment(self, deployment: Deployment) -> Deployment:
|
||||
def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]:
|
||||
"""
|
||||
Add or update deployment
|
||||
Parameters:
|
||||
|
@ -2610,8 +2893,17 @@ class Router:
|
|||
- The added/updated deployment
|
||||
"""
|
||||
# check if deployment already exists
|
||||
_deployment_model_id = deployment.model_info.id or ""
|
||||
_deployment_on_router: Optional[Deployment] = self.get_deployment(
|
||||
model_id=_deployment_model_id
|
||||
)
|
||||
if _deployment_on_router is not None:
|
||||
# deployment with this model_id exists on the router
|
||||
if deployment.litellm_params == _deployment_on_router.litellm_params:
|
||||
# No need to update
|
||||
return None
|
||||
|
||||
if deployment.model_info.id in self.get_model_ids():
|
||||
# if there is a new litellm param -> then update the deployment
|
||||
# remove the previous deployment
|
||||
removal_idx: Optional[int] = None
|
||||
for idx, model in enumerate(self.model_list):
|
||||
|
@ -2620,16 +2912,9 @@ class Router:
|
|||
|
||||
if removal_idx is not None:
|
||||
self.model_list.pop(removal_idx)
|
||||
|
||||
# add to model list
|
||||
_deployment = deployment.to_json(exclude_none=True)
|
||||
self.model_list.append(_deployment)
|
||||
|
||||
# initialize client
|
||||
self._add_deployment(deployment=deployment)
|
||||
|
||||
# add to model names
|
||||
self.model_names.append(deployment.model_name)
|
||||
else:
|
||||
# if the model_id is not in router
|
||||
self.add_deployment(deployment=deployment)
|
||||
return deployment
|
||||
|
||||
def delete_deployment(self, id: str) -> Optional[Deployment]:
|
||||
|
@ -2942,7 +3227,7 @@ class Router:
|
|||
):
|
||||
# check if in allowed_model_region
|
||||
if (
|
||||
_is_region_eu(model_region=_litellm_params["region_name"])
|
||||
_is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
|
||||
== False
|
||||
):
|
||||
invalid_model_indices.append(idx)
|
||||
|
@ -2966,7 +3251,7 @@ class Router:
|
|||
|
||||
if _rate_limit_error == True: # allow generic fallback logic to take place
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
|
||||
)
|
||||
elif _context_window_error == True:
|
||||
raise litellm.ContextWindowExceededError(
|
||||
|
@ -2990,11 +3275,15 @@ class Router:
|
|||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
):
|
||||
) -> Tuple[str, Union[list, dict]]:
|
||||
"""
|
||||
Common checks for 'get_available_deployment' across sync + async call.
|
||||
|
||||
If 'healthy_deployments' returned is None, this means the user chose a specific deployment
|
||||
|
||||
Returns
|
||||
- Dict, if specific model chosen
|
||||
- List, if multiple models chosen
|
||||
"""
|
||||
# check if aliases set on litellm model alias map
|
||||
if specific_deployment == True:
|
||||
|
@ -3004,7 +3293,7 @@ class Router:
|
|||
if deployment_model == model:
|
||||
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
|
||||
# return the first deployment where the `model` matches the specificed deployment name
|
||||
return deployment, None
|
||||
return deployment_model, deployment
|
||||
raise ValueError(
|
||||
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
|
||||
)
|
||||
|
@ -3020,7 +3309,7 @@ class Router:
|
|||
self.default_deployment
|
||||
) # self.default_deployment
|
||||
updated_deployment["litellm_params"]["model"] = model
|
||||
return updated_deployment, None
|
||||
return model, updated_deployment
|
||||
|
||||
## get healthy deployments
|
||||
### get all deployments
|
||||
|
@ -3034,7 +3323,9 @@ class Router:
|
|||
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
raise ValueError(f"No healthy deployment available, passed model={model}. ")
|
||||
raise ValueError(
|
||||
f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
|
||||
)
|
||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[
|
||||
model
|
||||
|
@ -3073,10 +3364,10 @@ class Router:
|
|||
messages=messages,
|
||||
input=input,
|
||||
specific_deployment=specific_deployment,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
if healthy_deployments is None:
|
||||
return model
|
||||
if isinstance(healthy_deployments, dict):
|
||||
return healthy_deployments
|
||||
|
||||
# filter out the deployments currently cooling down
|
||||
deployments_to_remove = []
|
||||
|
@ -3095,13 +3386,12 @@ class Router:
|
|||
healthy_deployments.remove(deployment)
|
||||
|
||||
# filter pre-call checks
|
||||
if self.enable_pre_call_checks and messages is not None:
|
||||
_allowed_model_region = (
|
||||
request_kwargs.get("allowed_model_region")
|
||||
if request_kwargs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.enable_pre_call_checks and messages is not None:
|
||||
if _allowed_model_region == "eu":
|
||||
healthy_deployments = self._pre_call_checks(
|
||||
model=model,
|
||||
|
@ -3122,8 +3412,10 @@ class Router:
|
|||
)
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
if _allowed_model_region is None:
|
||||
_allowed_model_region = "n/a"
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -3132,7 +3424,7 @@ class Router:
|
|||
):
|
||||
deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
healthy_deployments=healthy_deployments, # type: ignore
|
||||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
|
@ -3142,7 +3434,7 @@ class Router:
|
|||
):
|
||||
deployment = await self.lowestcost_logger.async_get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
healthy_deployments=healthy_deployments, # type: ignore
|
||||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
|
@ -3191,7 +3483,7 @@ class Router:
|
|||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
)
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
|
||||
|
@ -3220,8 +3512,8 @@ class Router:
|
|||
specific_deployment=specific_deployment,
|
||||
)
|
||||
|
||||
if healthy_deployments is None:
|
||||
return model
|
||||
if isinstance(healthy_deployments, dict):
|
||||
return healthy_deployments
|
||||
|
||||
# filter out the deployments currently cooling down
|
||||
deployments_to_remove = []
|
||||
|
@ -3245,7 +3537,7 @@ class Router:
|
|||
|
||||
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
||||
deployment = self.leastbusy_logger.get_available_deployments(
|
||||
model_group=model, healthy_deployments=healthy_deployments
|
||||
model_group=model, healthy_deployments=healthy_deployments # type: ignore
|
||||
)
|
||||
elif self.routing_strategy == "simple-shuffle":
|
||||
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
|
||||
|
@ -3293,7 +3585,7 @@ class Router:
|
|||
):
|
||||
deployment = self.lowestlatency_logger.get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
healthy_deployments=healthy_deployments, # type: ignore
|
||||
request_kwargs=request_kwargs,
|
||||
)
|
||||
elif (
|
||||
|
@ -3302,7 +3594,7 @@ class Router:
|
|||
):
|
||||
deployment = self.lowesttpm_logger.get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
healthy_deployments=healthy_deployments, # type: ignore
|
||||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
|
@ -3312,7 +3604,7 @@ class Router:
|
|||
):
|
||||
deployment = self.lowesttpm_logger_v2.get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
healthy_deployments=healthy_deployments, # type: ignore
|
||||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
|
@ -3321,7 +3613,7 @@ class Router:
|
|||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
)
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
|
||||
|
@ -3483,7 +3775,7 @@ class Router:
|
|||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.slack_alerting_instance.send_alert(
|
||||
message=f"Router: Cooling down deployment: {_api_base}, for {self.cooldown_time} seconds. Got exception: {str(exception_status)}",
|
||||
message=f"Router: Cooling down Deployment:\nModel Name: {_model_name}\nAPI Base: {_api_base}\n{self.cooldown_time} seconds. Got exception: {str(exception_status)}. Change 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
|
||||
alert_type="cooldown_deployment",
|
||||
level="Low",
|
||||
)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue